In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as cp

from sklearn import metrics, svm
from sklearn.preprocessing import LabelBinarizer, LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import pickle as pl

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import copy
import sys



In [4]:
# Cap the number of intra-op CPU threads, then choose the compute device:
# CUDA when PyTorch reports it as available, otherwise the CPU.
torch.set_num_threads(10)
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Hyperparameters (values follow the paper)
n_head, d_ff = 5, 1024            # attention heads / feed-forward width
dropout_rate = 0.3
learning_rate = 1e-4              # Adam step size
n_epochs, batch_size = 50, 16
gain = 1                          # not referenced by any code visible in this notebook

cuda


In [5]:
# Experiment configuration
rand_seed = 65      # seed for the sklearn splits and for the RNGs seeded below
n_gene = 1708       # width of the model input (presumably one column per gene)
n_feature = 1708    # first dimension of each attention weight column; expected to equal n_gene
n_class = 34        # number of output classes of the classifier
query_gene = 64     # forwarded to the model; no active code path reads it
val = True          # when true, a validation split is carved out and scored every epoch

# Seed the stochastic parts (weight initialisation, batch shuffling, dropout)
# so that a Restart & Run All reproduces the same training run.
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)

In [6]:
# Building multi attention model
class MultiAttention(torch.nn.Module):
    """Multi-head attention over the genes of a sample, using scalar projections.

    Every head owns three weight columns ``Wq``, ``Wk`` and ``Wv`` of shape
    ``(n_feature, 1)``. The input is viewed as ``(batch, n_gene, 1)`` and
    multiplied elementwise by those columns, so query, key and value are one
    scalar per gene. ``n_feature`` is therefore expected to equal ``n_gene``.
    The per-head outputs are mixed by the learned vector ``W0`` (one weight
    per head, initialised to 0.001).

    Args:
        batch_size: stored on the module; not used by the computation.
        n_head: number of attention heads.
        n_gene: number of genes (length of the second input dimension).
        n_feature: first dimension of each per-head weight column.
        query_gene: stored on the module; not used by the computation.
        mode: stored on the module; not used by the computation.
    """

    def __init__(self, batch_size, n_head, n_gene, n_feature, query_gene, mode):
        super(MultiAttention, self).__init__()
        self.batch_size = batch_size
        self.n_head = n_head
        self.n_gene = n_gene
        self.n_feature = n_feature
        self.query_gene = query_gene
        self.mode = mode

        # Query, key and value weights: one (n_feature, 1) column per head.
        # torch.empty makes the "allocated, then initialised below" intent explicit.
        self.Wq = nn.Parameter(torch.empty(self.n_head, n_feature, 1), requires_grad=True)
        self.Wk = nn.Parameter(torch.empty(self.n_head, n_feature, 1), requires_grad=True)
        self.Wv = nn.Parameter(torch.empty(self.n_head, n_feature, 1), requires_grad=True)

        # In-place Xavier initialisation (the non-underscore name is deprecated).
        torch.nn.init.xavier_normal_(self.Wq, gain=1)
        torch.nn.init.xavier_normal_(self.Wk, gain=1)
        torch.nn.init.xavier_normal_(self.Wv, gain=1)

        # Head-mixing weights, one scalar per head.
        self.W0 = nn.Parameter(torch.full((self.n_head,), 0.001), requires_grad=True)

    def QK_difference(self, Q_seq, K_seq):
        """Return the softmax over dim 2 of the negated squared difference Q - K.

        Not called by ``forward``; kept for API compatibility.
        """
        QK_diff = torch.pow((Q_seq - K_seq), 2) * -1
        return torch.nn.Softmax(dim=2)(QK_diff)

    def mask_softmax(self, x):
        """Zero the diagonal of ``x`` (a gene's attention to itself).

        The mask is built on the device of ``x`` rather than on a notebook
        global, so the module works wherever its input lives. Rows are not
        re-normalised after masking.
        """
        d = x.shape[1]
        off_diagonal = 1 - torch.eye(d, d, device=x.device, dtype=x.dtype)
        return x * off_diagonal

    def attention(self, x, Q_seq, Wk, Wv):
        """Run one attention head.

        Args:
            x: input of shape (batch, n_gene, 1).
            Q_seq: queries already expanded to (batch, n_gene, n_gene), constant along dim 2.
            Wk: key weight column of this head.
            Wv: value weight column of this head.

        Returns:
            Tensor of shape (batch, n_gene, 1).
        """
        K_seq = x * Wk
        K_seq = K_seq.expand(K_seq.shape[0], K_seq.shape[1], self.n_gene)
        # After the permute, entry [b, i, j] holds the key of gene j.
        K_seq = K_seq.permute(0, 2, 1)
        V_seq = x * Wv

        # Score of gene i attending to gene j is q_i * k_j; normalise over j.
        QK_product = Q_seq * K_seq
        z = torch.nn.Softmax(dim=2)(QK_product)
        z = self.mask_softmax(z)
        return torch.matmul(z, V_seq)

    def forward(self, x):
        """Apply all heads to ``x`` of shape (batch, n_gene); returns (batch, n_gene)."""
        x = torch.reshape(x, (x.shape[0], x.shape[1], 1))
        output_h = []
        for h in range(self.n_head):
            Q_seq = x * self.Wq[h, :, :]
            Q_seq = Q_seq.expand(Q_seq.shape[0], Q_seq.shape[1], self.n_gene)
            output_h.append(self.attention(x, Q_seq, self.Wk[h, :, :], self.Wv[h, :, :]))

        # (batch, n_gene, n_head) @ (n_head,) -> (batch, n_gene)
        output_seq = torch.cat(output_h, dim=2)
        return torch.matmul(output_seq, self.W0)

In [7]:
# Layer normalization
class LayerNorm(nn.Module):
    """Normalise the last dimension to zero mean and unit spread, then rescale.

    Args:
        features: size of the last dimension; sets the length of the learned
            scale ``a_2`` (initialised to ones) and shift ``b_2`` (zeros).
        eps: constant added to the standard deviation before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along the last dimension."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation itself, not to the variance.
        denominator = x.std(-1, keepdim=True) + self.eps
        scaled = self.a_2 * centered
        return scaled / denominator + self.b_2

# Residual connection: adds the input to the dropped-out, layer-normalized sublayer output
class ResidualConnect(nn.Module):
    """Skip connection around a sublayer output.

    The sublayer output goes through dropout and then layer normalisation,
    and the result is added to the untouched input.

    Args:
        size: length of the last dimension, passed to ``LayerNorm``.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, out):
        """Return ``x + norm(dropout(out))``."""
        dropped = self.dropout(out)
        normalized = self.norm(dropped)
        return x + normalized

In [8]:
# Building Network
class Net(nn.Module):
    """Classifier made of three stacked ``MultiAttention`` layers and a linear head.

    ``forward`` passes the input through three attention layers, each followed
    by the same shared ``ResidualConnect`` block, applies a LeakyReLU and a
    linear layer, and returns log-probabilities over ``n_class`` classes.

    Args:
        batch_size, n_head, n_gene, n_feature, query_gene, mode: forwarded to
            every ``MultiAttention`` layer.
        n_class: number of output classes.
        dropout_rate: stored on the module (see the note in ``__init__``).
        d_ff: stored on the module (see the note in ``__init__``).
    """

    def __init__(self, batch_size, n_head, n_gene, n_feature, n_class, query_gene, dropout_rate, d_ff, mode):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.n_head = n_head
        self.n_gene = n_gene
        self.n_feature = n_feature
        self.n_class = n_class
        self.query_gene = query_gene
        self.dropout_rate = dropout_rate
        self.d_ff = d_ff

        # multi attention layers
        self.multi_attn1 = MultiAttention(self.batch_size, self.n_head, self.n_gene, self.n_feature, query_gene, mode)
        self.multi_attn2 = MultiAttention(self.batch_size, self.n_head, self.n_gene, self.n_feature, query_gene, mode)
        self.multi_attn3 = MultiAttention(self.batch_size, self.n_head, self.n_gene, self.n_feature, query_gene, mode)

        # fully connected classification head, Xavier-uniform initialised
        self.fc = nn.Linear(self.n_gene, self.n_class)
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=1)

        # Position-wise feed-forward layers (only used by feed_forward()).
        # Their outer width follows n_gene instead of a hard-coded 1708.
        # NOTE(review): the hidden width (1024) and the dropout probability (0.3)
        # are literals rather than d_ff / dropout_rate; confirm that every caller
        # passes those two arguments in signature order before wiring them in.
        self.ffn1 = nn.Linear(self.n_gene, 1024)
        self.ffn2 = nn.Linear(1024, self.n_gene)

        self.dropout = nn.Dropout(0.3)
        # A single residual block (and therefore a single set of LayerNorm
        # parameters) is shared by all three attention layers.
        self.sublayer = ResidualConnect(n_gene, 0.3)

    def feed_forward(self, x):
        """Two-layer ReLU MLP with dropout; not called by ``forward``."""
        output = F.relu(self.ffn1(x))
        output = self.ffn2(self.dropout(output))
        return output

    def forward(self, x):
        """Map ``x`` of shape (batch, n_gene) to log-probabilities of shape (batch, n_class)."""
        output_attn = self.multi_attn1(x)
        output_attn1 = self.sublayer(x, output_attn)
        output_attn2 = self.multi_attn2(output_attn1)
        output_attn2 = self.sublayer(output_attn1, output_attn2)
        output_attn3 = self.multi_attn3(output_attn2)
        output_attn3 = self.sublayer(output_attn2, output_attn3)

        # leaky relu activation (functional form; no module is rebuilt per call)
        output_attn3 = F.leaky_relu(output_attn3, negative_slope=0.1)

        y_pred = self.fc(output_attn3)
        # log-probabilities, the input expected by F.nll_loss
        y_pred = F.log_softmax(y_pred, dim=1)

        return y_pred
        

In [9]:
# Load the pickled dataset. The context manager closes the file handle.
# NOTE(security): unpickling executes code embedded in the file; only load
# pickles from a trusted source.
with open('/kaggle/input/tgem-data/pathway_data.pckl', 'rb') as pathway_file:
    y, data_df, pathway_gene, pathway, cancer_name = pl.load(pathway_file)

data = np.array(data_df)
x = np.float32(data)
gene_list = data_df.columns.tolist()

# Encode the raw labels as consecutive integers.
encoder = LabelEncoder()
y_label = encoder.fit_transform(y)
class_label = np.unique(y)

# Keep the n_class most frequent classes, ordered from most to least frequent.
u, count = np.unique(y_label, return_counts=True)
count_sort_ind = np.argsort(-count)
y_label_unique34 = u[count_sort_ind[0:n_class]]

x_top34 = []
y_top34 = []
sample_size = []

# Collect the samples of each kept class and relabel the class as its
# frequency rank j, so the targets are 0 .. n_class - 1.
for j, sample_label in enumerate(y_label_unique34):
    sample_index = np.argwhere(y_label == sample_label)[:, 0]
    sample_size.append(sample_index.shape)
    x_top34.append(x[sample_index])
    temp_y = np.full(sample_index.shape, j, dtype=y_label.dtype)
    y_top34.append(temp_y)
    

In [10]:
# Train / test: split every class on its own (20% held out per class) so each
# class is represented in both partitions.
X_train, X_test, y_train, y_test = [], [], [], []
for class_x, class_y in zip(x_top34, y_top34):
    parts = train_test_split(class_x, class_y, test_size=0.2, random_state=rand_seed)
    X_train.append(parts[0])
    X_test.append(parts[1])
    y_train.append(parts[2])
    y_test.append(parts[3])

# Optional validation set: 10% of each class's training portion.
if val == True:
    X_train_validation, X_validation = [], []
    y_train_validation, y_validation = [], []

    for class_x, class_y in zip(X_train, y_train):
        parts = train_test_split(class_x, class_y, test_size=0.1, random_state=rand_seed)
        X_train_validation.append(parts[0])
        X_validation.append(parts[1])
        y_train_validation.append(parts[2])
        y_validation.append(parts[3])

    # From here on "train" means the part left after removing validation samples.
    X_train, y_train = X_train_validation, y_train_validation
    X_val_input = torch.from_numpy(np.vstack(X_validation))
    y_val_input = torch.from_numpy(np.hstack(y_validation))

# Concatenate the per-class pieces into single tensors.
X_train_input = torch.from_numpy(np.vstack(X_train))
y_train_input = torch.from_numpy(np.hstack(y_train))
X_test_input = torch.from_numpy(np.vstack(X_test))
y_test_input = torch.from_numpy(np.hstack(y_test))

In [11]:
# Build the network with keyword arguments. Net's signature orders the last
# parameters as (..., dropout_rate, d_ff, mode); the earlier positional call
# passed d_ff and dropout_rate in each other's slot.
model = Net(
    batch_size=batch_size,
    n_head=n_head,
    n_gene=n_gene,
    n_feature=n_feature,
    n_class=n_class,
    query_gene=query_gene,
    dropout_rate=dropout_rate,
    d_ff=d_ff,
    mode=0,
).to(device)

# Adam with its default moment coefficients and no weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    amsgrad=False,
)

  app.launch_new_instance()


In [13]:
# Per-epoch histories: losses for plotting, validation metrics for the results file.
train_loss_list = []
val_loss_list = []
results = {}
confusion_matrix_result = []
mcc_result = []
acc_result = []
auc_result = []
f1_result = []

n_train = X_train_input.size()[0]

for epoch in range(n_epochs):
    print(epoch)

    # ---- training pass over a fresh shuffle of the training set ----
    model.train()
    train_loss = 0.0
    permutation = torch.randperm(n_train)
    for batch_idx, i in enumerate(range(0, n_train, batch_size)):
        indices = permutation[i:i + batch_size]
        batch_x, batch_y = X_train_input[indices], y_train_input[indices]
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()

        y_pred = model(batch_x.float())
        loss = F.nll_loss(y_pred, batch_y)
        loss.backward()
        optimizer.step()

        # F.nll_loss returns the batch mean. Weight it by the batch size so that
        # dividing by n_train below yields a per-sample average on the same
        # scale as the validation loss.
        train_loss += loss.item() * len(batch_x)

        if batch_idx % 10 == 0:
            # i is the number of samples already consumed in this epoch
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, i, n_train,
                           100. * i / n_train, loss.item()))
    train_loss /= n_train
    train_loss_list.append(train_loss)

    # ---- validation pass ----
    if val:
        model.eval()
        n_val = X_val_input.size()[0]
        correct_val = 0
        val_loss = 0.0
        batch_prediction = []
        batch_y_validation_list = []
        batch_prediction_category = []
        with torch.no_grad():
            # No shuffling needed: every metric below is order-independent.
            for i in range(0, n_val, batch_size):
                batch_x_validation = X_val_input[i:i + batch_size].to(device)
                batch_y_validation = y_val_input[i:i + batch_size].to(device)

                output_validation = model(batch_x_validation.float())
                val_loss += F.nll_loss(output_validation, batch_y_validation, reduction="sum").item()
                pred_validation = output_validation.argmax(dim=1, keepdim=True)

                correct_val += pred_validation.eq(batch_y_validation.view_as(pred_validation)).sum().item()
                batch_prediction.append(pred_validation.cpu().numpy())
                batch_y_validation_list.append(batch_y_validation.cpu().numpy())
                batch_prediction_category.append(output_validation.cpu().numpy())

        val_loss /= n_val
        val_loss_list.append(val_loss)

        print('\nValidation Set: Average Loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            val_loss, correct_val, n_val,
            100.*correct_val / n_val))

        yy_validation = np.hstack(batch_y_validation_list)
        ppred_classes = np.vstack(batch_prediction).ravel()

        acc_validation = accuracy_score(yy_validation, ppred_classes)
        f1 = f1_score(yy_validation, ppred_classes, average = "micro")

        confusion_matrix = metrics.confusion_matrix(yy_validation, ppred_classes)
        mcc = metrics.matthews_corrcoef(yy_validation, ppred_classes)

        # Binarise against the full label set so the indicator columns always
        # line up with the n_class columns of the model output, even if some
        # class happens to be absent from the validation labels.
        encoder_ = LabelBinarizer()
        encoder_.fit(np.arange(n_class))
        yy_val = encoder_.transform(yy_validation)
        # The model emits log-probabilities; exp() turns them back into probabilities.
        roc_auc = metrics.roc_auc_score(yy_val, np.exp(np.vstack(batch_prediction_category)), multi_class='ovr', average="micro")

        confusion_matrix_result.append(confusion_matrix)
        mcc_result.append(mcc)
        acc_result.append(acc_validation)
        auc_result.append(roc_auc)
        f1_result.append(f1)

        # Checkpoint after every validated epoch (the file is overwritten each time).
        torch.save(model, "/kaggle/working/transformer.model")
            

0

Validation Set: Average Loss: 2.0054, Accuracy: 530/900 (58.89%)

1

Validation Set: Average Loss: 1.5811, Accuracy: 681/900 (75.67%)

2

Validation Set: Average Loss: 0.7643, Accuracy: 728/900 (80.89%)

3

Validation Set: Average Loss: 0.4405, Accuracy: 782/900 (86.89%)

4

Validation Set: Average Loss: 0.3875, Accuracy: 793/900 (88.11%)

5

Validation Set: Average Loss: 0.3616, Accuracy: 798/900 (88.67%)

6

Validation Set: Average Loss: 0.3551, Accuracy: 800/900 (88.89%)

7

Validation Set: Average Loss: 0.3375, Accuracy: 804/900 (89.33%)

8

Validation Set: Average Loss: 0.3346, Accuracy: 812/900 (90.22%)

9

Validation Set: Average Loss: 0.3169, Accuracy: 813/900 (90.33%)

10

Validation Set: Average Loss: 0.3073, Accuracy: 823/900 (91.44%)

11

Validation Set: Average Loss: 0.2908, Accuracy: 826/900 (91.78%)

12

Validation Set: Average Loss: 0.2945, Accuracy: 822/900 (91.33%)

13

Validation Set: Average Loss: 0.2902, Accuracy: 823/900 (91.44%)

14

Validation Set: Average Lo

In [None]:
# Gather the per-epoch validation metrics under the keys used in the results file.
results['confusion matrix'] = confusion_matrix_result
results['mcc'] = mcc_result
results['f1'] = f1_result
results['acc'] = acc_result
results['auc'] = auc_result

# Write through a context manager so the file is flushed and closed
# before the cell finishes.
with open("/kaggle/working/results.dat", 'wb') as results_file:
    pl.dump(results, results_file)

# Loss curves, one point per epoch.
plt.plot(train_loss_list, label='Training Loss')
plt.plot(val_loss_list , label="Validation Loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('plot.png', format='png')
plt.close()