In [1]:
import os
from itertools import product

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pecanpy as pp
import scipy
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neural_network import MLPClassifier
import torch
import torch.nn as nn

In [2]:
np.random.seed(42)

# Dataset

In [3]:
loaded_data = {}
def get_data(cell_type):
    # Get graph
    data_folder = './data/training_matrices_DGL/'
    if cell_type in loaded_data:
        mat = loaded_data[cell_type]
    else:
        mat = pd.read_csv(data_folder + f'{cell_type}.feat.mat', index_col=0, sep='\t')
        loaded_data[cell_type] = mat
    genes = np.array(mat.columns)
    assert (genes == np.array(mat.index)).all()

    # Get labels
    meta = pd.read_csv(data_folder + 'training_labels.txt', index_col=0, sep='\t')
    labels = meta.values[:, 0]
    labels = np.array([str(meta.loc[g]['label']) if g in meta.index else 'unknown' for g in genes])

    return mat.to_numpy(), labels, genes

# Model

In [4]:
cell_type = 'Astrocyte'

In [5]:
print(f'Loading {cell_type} data...')
# Load data
data, labels, genes = get_data(cell_type)

Loading Astrocyte data...


In [6]:
print('Constructing graph...')
# Construct graph
g = pp.graph.AdjlstGraph()
for i in range(data.shape[0]):
    for j in range(i+1, data.shape[0]):
        weight = float(data[i][j])
        if weight != 0:
            g.add_edge(genes[i], genes[j], weight=weight, directed=False)
g.save('_elist.edg')

Constructing graph...


In [7]:
print('Running Node2Vec+...')
print('\tReading graph...')
# Load as precomp
g = pp.pecanpy.SparseOTF(p=1, q=1, workers=4, verbose=True)
g.read_edg('_elist.edg', weighted=True, directed=False)
# g.preprocess_transition_probs()

print('\tGenerate embeddings...')
# Generate embeddings
dim, num_walks, walk_length = 16, 10, 20
emb = g.embed(dim=dim, num_walks=num_walks, walk_length=walk_length)
np.save(f'embeddings-{emb.shape[1]}-{num_walks}-{walk_length}.npy', emb)

# Chart connected subgraph
surviving_nodes = [np.argwhere(genes==gn).flatten()[0] for gn in g.nodes]

Running Node2Vec+...
	Reading graph...
	Generate embeddings...


  0%|                                                                                        | 0/108020 [00:00…

In [8]:
print('Splitting data...')
# Generate train_idx
train_frac = .8
train_idx = np.random.choice(emb.shape[0], int(train_frac * emb.shape[0]), replace=False)
train_idx = np.intersect1d(train_idx, np.array(list( set(list(range(emb.shape[0]))) - set(list(np.argwhere(labels[surviving_nodes]=='unknown').flatten())) )))
eval_idx = np.array(list(set(list(range(emb.shape[0]))) - set(train_idx)))

Splitting data...


In [9]:
class FCL(nn.Module):
    def __init__(self, input_dim, hidden_dim=None, output_dim=1):
        super().__init__()
        
        self.input_dim = input_dim
        hidden_dim = hidden_dim if hidden_dim is not None else int(input_dim)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        self.main = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # nn.Dropout(.8),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            
            nn.Linear(hidden_dim, hidden_dim),
            # nn.Dropout(.8),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.Sigmoid(),
        )
        
    def forward(self, X):
        return self.main(X)

In [10]:
print('Creating classification model...')
# Balanced train
batch_size = 128
max_lapses = 100

mlp = FCL(emb.shape[1])
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-5)
criterion = nn.BCELoss()
best_loss = torch.inf; lapses = 0
for epoch in range(10001):
    epoch_loss = 0
    
    mlp.train()
    for batch in range(emb.shape[0] // batch_size):
        seg_size = int(batch_size / 2)
        ad_batch_idx = np.random.choice(
            np.intersect1d(train_idx, np.argwhere(labels[surviving_nodes]=='AD').squeeze()),
            seg_size,
            replace=True)
        notad_batch_idx = np.random.choice(
            np.intersect1d(train_idx, np.argwhere(labels[surviving_nodes]=='notAD').squeeze()),
            seg_size,
            replace=True)
        batch_idx = np.concatenate([ad_batch_idx, notad_batch_idx])

        true = torch.cat([torch.ones((seg_size, 1)), torch.zeros((seg_size, 1))])
        logits = mlp(torch.Tensor(emb[batch_idx]))
        loss = criterion(logits, true)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()
    epoch_loss = epoch_loss / (emb.shape[0] // batch_size)
    
    # Evaluation loss
    mlp.eval()
    logits = mlp(torch.Tensor(emb[eval_idx])).detach().squeeze()
    trans = {'AD': 1, 'notAD': 0, 'unknown': 2}
    trans_inv = {v: k for k, v in trans.items()}
    true = torch.Tensor([trans[l] for l in labels[surviving_nodes][eval_idx]])
    eval_loss = criterion(logits[np.argwhere(true!=2).squeeze()], true[np.argwhere(true!=2).squeeze()]).detach()
    
    # CLI
    if epoch % 25 == 0:
        print(f'Epoch {epoch}:\tTrain_Loss {float(epoch_loss):.3f},\tEval_Loss {float(eval_loss):.3f}')
    
    # Early stopping
    lapses += 1
    if eval_loss < best_loss:
        best_loss = eval_loss
        lapses = 0
    if lapses >= max_lapses:
        print(f'Epoch {epoch}\tTrain_Loss {float(epoch_loss):.3f},\tEval_Loss {float(eval_loss):.3f}')
        print('Stopped!')
        break
mlp.eval();

Creating classification model...
Epoch 0:	Train_Loss 0.815,	Eval_Loss 0.799
Epoch 25:	Train_Loss 0.695,	Eval_Loss 0.706
Epoch 50:	Train_Loss 0.631,	Eval_Loss 0.680
Epoch 75:	Train_Loss 0.591,	Eval_Loss 0.676
Epoch 100:	Train_Loss 0.564,	Eval_Loss 0.628
Epoch 125:	Train_Loss 0.544,	Eval_Loss 0.599
Epoch 150:	Train_Loss 0.537,	Eval_Loss 0.597
Epoch 175:	Train_Loss 0.523,	Eval_Loss 0.588
Epoch 200:	Train_Loss 0.518,	Eval_Loss 0.574
Epoch 225:	Train_Loss 0.500,	Eval_Loss 0.564
Epoch 250:	Train_Loss 0.496,	Eval_Loss 0.551
Epoch 275:	Train_Loss 0.487,	Eval_Loss 0.553
Epoch 300:	Train_Loss 0.479,	Eval_Loss 0.563
Epoch 325:	Train_Loss 0.468,	Eval_Loss 0.554
Epoch 350:	Train_Loss 0.465,	Eval_Loss 0.538
Epoch 375:	Train_Loss 0.451,	Eval_Loss 0.527
Epoch 400:	Train_Loss 0.448,	Eval_Loss 0.514
Epoch 425:	Train_Loss 0.435,	Eval_Loss 0.510
Epoch 450:	Train_Loss 0.432,	Eval_Loss 0.518
Epoch 475:	Train_Loss 0.425,	Eval_Loss 0.517
Epoch 493	Train_Loss 0.421,	Eval_Loss 0.531
Stopped!


# Evaluation

In [11]:
print('Evaluating performance...')
print('Train')
## Train
# Perform prediction
logits = mlp(torch.Tensor(emb[train_idx])).detach().squeeze()
trans = {'AD': 1, 'notAD': 0, 'unknown': 2}
trans_inv = {v: k for k, v in trans.items()}
true = torch.Tensor([trans[l] for l in labels[surviving_nodes][train_idx]]).long()

# Get confusion
conf = confusion_matrix(true, 1*(logits > .5))
print('T\P\t' + '\t'.join([trans_inv[i] for i in range(len(trans_inv)-1)]))
for i, row in enumerate(conf):
    print(trans_inv[i] + '\t' + '\t'.join([str(e) for e in row]))

# Other statistics
fpr, tpr, thresholds = roc_curve(true[np.argwhere(true!=2).squeeze()], logits[np.argwhere(true!=2).squeeze()])
print(f'AUROC:\t{auc(fpr, tpr):.4f}')
prec, rec, thresholds = precision_recall_curve(true[np.argwhere(true!=2).squeeze()], logits[np.argwhere(true!=2).squeeze()])
print(f'AUPRC:\t{auc(rec, prec):.4f}')
print()

print('Eval')
## Eval
# Perform prediction
logits = mlp(torch.Tensor(emb[eval_idx])).detach().squeeze()
trans = {'AD': 1, 'notAD': 0, 'unknown': 2}
trans_inv = {v: k for k, v in trans.items()}
true = torch.Tensor([trans[l] for l in labels[surviving_nodes][eval_idx]]).long()

# Get confusion
conf = confusion_matrix(true, 1*(logits > .5))
print('T\P\t' + '\t'.join([trans_inv[i] for i in range(len(trans_inv))]))
for i, row in enumerate(conf):
    print(trans_inv[i] + '\t' + '\t'.join([str(e) for e in row]))

# Other statistics
fpr, tpr, thresholds = roc_curve(true[np.argwhere(true!=2).squeeze()], logits[np.argwhere(true!=2).squeeze()])
print(f'AUROC:\t{auc(fpr, tpr):.4f}')
prec, rec, thresholds = precision_recall_curve(true[np.argwhere(true!=2).squeeze()], logits[np.argwhere(true!=2).squeeze()])
print(f'AUPRC:\t{auc(rec, prec):.4f}')
print()
    
print('Recording predictions...')
# Write predicted AD to file (only unknown genes)
unk_idx = np.argwhere(labels[surviving_nodes][eval_idx]=='unknown').squeeze()
np.savetxt('AD.txt', genes[surviving_nodes][eval_idx][np.intersect1d(unk_idx, np.argwhere(logits > .8).squeeze())], fmt='%s')

Evaluating performance...
Train
T\P	notAD	AD
notAD	1943	469
AD	1	102
AUROC:	0.9478
AUPRC:	0.3194

Eval
T\P	notAD	AD	unknown
notAD	437	157	0
AD	14	4	0
unknown	5905	1770	0
AUROC:	0.4988
AUPRC:	0.0293

Recording predictions...
