In [6]:
import os
import re

import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pecanpy as pp
import seaborn as sns
import sklearn.metrics
import torch
import torch.nn as nn

# Style
sns.set_theme(context='talk', style='white', palette='Set2')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


# Preparation

In [7]:
# Get files for contrast
# contrast = 'c02x'
dir = '../data/scenic_outs/'
# fnames = [fname for fname in os.listdir(dir) if fname.startswith(f'{contrast}_')]
fnames = os.listdir(dir)

# Get groups and cell-type based on fname
# TODO: Doesn't work for groups like 'AD_resilient'
m1 = re.compile('^([cr]\d+x(_\d+)?)_([a-zA-Z]+(_(resilient|strict))?)_((?!resilient|strict)\w+)_regulon_list.csv$')
get_contrast = lambda fname: m1.match(fname).group(1)
get_group = lambda fname: m1.match(fname).group(3)
get_cell_type = lambda fname: m1.match(fname).group(6)

# Get labels
gene_dir = '../data/new_labels/'
gene_fnames = [fname for fname in os.listdir(gene_dir) if fname.endswith('.txt')]
gene_lists = {'.'.join(fname.split('.')[:-1]): np.loadtxt(os.path.join(gene_dir, fname), dtype=str) for fname in gene_fnames}

# Special cases
gene_lists['scz'] = gene_lists['SCZ']
gene_lists['AD_resilient'] = gene_lists['AD_strict'] = gene_lists['AD']


# Analyses

### TF Dot Analysis

In [8]:
for fname in fnames:
    ### Reading
    # Choose graph
    contrast = get_contrast(fname)
    group = get_group(fname)
    cell_type = get_cell_type(fname)
    print(' - '.join([fname, group, cell_type]))

    # Escape if not compatible
    if group not in gene_lists:
        print()
        continue

    # Get TF-TG linkages
    graph_list = pd.read_csv(os.path.join(dir, fname), index_col=0)
    graph_list = graph_list.rename(columns={'gene': 'TG', 'CoexWeight': 'coex'})

    # Get matrix
    graph_matrix = graph_list.pivot(index='TF', columns='TG', values='coex').fillna(0)

    # Get tf matrix
    # tf_matrix = pd.DataFrame(graph_matrix.to_numpy() @ graph_matrix.to_numpy().T, index=graph_matrix.index, columns=graph_matrix.index)  # TF
    tf_matrix = pd.DataFrame(graph_matrix.to_numpy().T @ graph_matrix.to_numpy(), index=graph_matrix.columns, columns=graph_matrix.columns)  # TG

    # Annotate
    gene_list = gene_lists[group]
    annotation = tf_matrix.index.map(lambda g: g in gene_list).to_numpy()

    # Sort based on annotation
    # annotation_sort_idx = list(annotation.argsort())[::-1]
    # tf_matrix = tf_matrix.iloc[annotation_sort_idx, annotation_sort_idx]
    # annotation = annotation[annotation_sort_idx]

    ### Processing
    # Sort genes based on dot with known TFs
    score = (tf_matrix.to_numpy() - np.diag(tf_matrix.to_numpy().diagonal()))[:, annotation].sum(axis=1)
    score[annotation] *= float(len(annotation)) / (len(annotation) - 1)
    sorted_idx = score.argsort()[::-1]

    ### Analysis
    # Evaluate performance
    if annotation[sorted_idx].sum() != 0:
        average_positive_percentile = np.linspace(1, 0, num=tf_matrix.shape[0])[annotation[sorted_idx]].mean()
    else:
        average_positive_percentile = np.nan
    print(f'Average positive percentile of {average_positive_percentile:.3f}')
    positive_unknown_genes = tf_matrix.index.to_numpy()[(score >= np.percentile(score, 100 - 100*float(annotation.sum()) / score.shape[0])) * ~annotation]
    print(f'Positive unknown genes: {positive_unknown_genes}')
    negative_positive_genes = tf_matrix.index.to_numpy()[(score < np.percentile(score, 100 - 100*float(annotation.sum()) / score.shape[0])) * annotation]
    print(f'Negative positive genes: {negative_positive_genes}')
    df = pd.DataFrame({
        # 'TF': tf_matrix.index.to_numpy()[sorted_idx],  # TF
        'TG': tf_matrix.index.to_numpy()[sorted_idx],  # TG
        'score': score[sorted_idx],
        'percentile': np.linspace(1, 0, num=tf_matrix.shape[0]),
        'annotation': annotation[sorted_idx],
    })

    ### Visualization
    # Sort based on score
    score_sort_idx = list(score.argsort())[::-1]
    tf_matrix_sort_score = tf_matrix.iloc[score_sort_idx, score_sort_idx]
    annotation_sort_score = annotation[score_sort_idx]
    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(9, 9))
    plt.title(f'{group}_{cell_type} - APP {average_positive_percentile:.3f}')
    # tf_matrix_no_diag = tf_matrix.copy()
    # for i in range(tf_matrix.shape[0]): tf_matrix_no_diag.iloc[i, i] = 0
    sns.heatmap(tf_matrix_sort_score, norm=LogNorm(), cmap='mako_r', ax=ax)
    for i in np.argwhere(annotation_sort_score): plt.axhline(y=i, color='red', linewidth=1)
    fig.savefig(f'./plots/tf_matrix_{contrast}_{group}_{cell_type}.pdf', format='pdf', transparent=True)
    plt.tight_layout()
    plt.close()

    print()


c02x_AD_Astro_regulon_list.csv - AD - Astro
Average positive percentile of 0.469
Positive unknown genes: ['AEBP1' 'ANGPTL4' 'ANOS1' 'APLNR' 'ARID5A' 'ATF3' 'BAG3' 'BARD1' 'BCL6'
 'BHLHE40' 'BTG2' 'C1R' 'C1RL' 'C4orf19' 'CACHD1' 'CDKN1A' 'CEBPB' 'CEBPD'
 'CLIC4' 'CNN3' 'CRISPLD1' 'CRYAB' 'CSRNP1' 'CYCS' 'DBI' 'DDIT4' 'DTNA'
 'DUSP1' 'ELL2' 'EMP1' 'EPAS1' 'ETV6' 'FGF2' 'FOS' 'FOSL2' 'GADD45A'
 'GADD45B' 'GEM' 'GFAP' 'GHR' 'GPCPD1' 'GRIA1' 'HELB' 'HES1' 'HMGCS1'
 'HSP90AA1' 'HSPA1A' 'HSPB1' 'HSPB8' 'HSPD1' 'HSPH1' 'ID2' 'ID3' 'IDI1'
 'IER2' 'IGFBP7' 'IL13RA1' 'IL1R1' 'ITPKB' 'JUN' 'JUNB' 'KIAA0040' 'KLF6'
 'LAMA1' 'LIMK2' 'LMO2' 'LTBP1' 'MAP3K14' 'MATN2' 'MIDN' 'MSN' 'NAMPT'
 'NFIL3' 'NR4A1' 'NRP1' 'NRP2' 'OSMR' 'P2RY6' 'PAM' 'PARP9' 'PDPN'
 'PFKFB2' 'PLSCR1' 'PMP2' 'RASD1' 'RELL1' 'RFX4' 'RGS16' 'RND3' 'RNF122'
 'SAMD4A' 'SBNO2' 'SCARA3' 'SH3GL2' 'SLC38A1' 'SLC38A2' 'SMAD3' 'SOCS3'
 'STOM' 'SULF1' 'TANC1' 'TEAD3' 'TEAD4' 'TIPARP' 'TNFRSF1A' 'TOB1'
 'TRIP10' 'TXNIP' 'UBASH3B' 'UBC' 'WIPF1

### Node2Vec Analysis

In [None]:
class MLP(torch.nn.Module):
    def __init__(self, input_size, dropout=.6):
        super().__init__()
        self.mlp = nn.Sequential(
            # Multiple layer
            # nn.Linear(input_size, input_size//2),
            # nn.Dropout(dropout),
            # nn.BatchNorm1d(input_size//2),
            # nn.LeakyReLU(),

            # nn.Linear(input_size//2, input_size//4),
            # nn.Dropout(dropout),
            # nn.BatchNorm1d(input_size//4),
            # nn.LeakyReLU(),

            # nn.Linear(input_size//4, 2),
            # nn.Dropout(dropout),
            # nn.BatchNorm1d(2),
            # nn.Softmax(1),

            # Single hidden layer
            # nn.Linear(input_size, 64),
            # nn.Dropout(dropout),
            # nn.BatchNorm1d(64),
            # nn.LeakyReLU(),

            # nn.Linear(64, 2),
            # nn.Dropout(dropout),
            # nn.BatchNorm1d(2),
            # nn.Softmax(1),

            # Single layer
            nn.Linear(input_size, 2),
            nn.BatchNorm1d(2),
            nn.Softmax(1),
        )

    def forward(self, X):
        return self.mlp(X)


In [None]:
# Parameters
dim = 64  # Automatic if `features``=='coex'
lr = 1e-1
gamma = .96
features = ('embeddings', 'coex')[0]

replications = 3
folds = 5  # 1 indicates 80-20 validation, 0 indicates no validation
max_epochs = 1_001
batch_size = 64
max_lapses = 20

# Books
graph_type = []
app_list = []
auprc_list = []
auroc_list = []

# Loop
for fname in fnames:
    ### Reading
    # Choose graph
    contrast = get_contrast(fname)
    group = get_group(fname)
    cell_type = get_cell_type(fname)
    print(' - '.join([fname, group, cell_type]))

    # Escape if not compatible
    if group not in gene_lists:
        print()
        continue

    # Get TF-TG linkages
    graph_list = pd.read_csv(os.path.join(dir, fname), index_col=0)
    graph_list = graph_list.rename(columns={'gene': 'TG', 'CoexWeight': 'coex'})

    # Get matrix
    graph_matrix = graph_list.pivot(index='TF', columns='TG', values='coex').fillna(0)
    # Make square
    all_genes = np.unique(list(graph_matrix.index) + list(graph_matrix.columns))
    for gene in all_genes:
        if gene not in graph_matrix.index: graph_matrix.loc[gene] = 0
        if gene not in graph_matrix.columns: graph_matrix[gene] = 0


    ### Processing
    for replication in range(replications):
        print(f'Replication {replication:02d}')
        # Reproducibility
        replication_seed = 42+replication
        np.random.seed(replication_seed)

        # Split into folds
        if folds > 1:
            folds_idx = np.array_split(np.array(range(np.unique(np.concatenate([np.unique(graph_list['TF']), np.unique(graph_list['TG'])])).shape[0])), folds)
        elif folds == 1:
            folds_idx = np.array(list(range(np.unique(np.concatenate([np.unique(graph_list['TF']), np.unique(graph_list['TG'])])).shape[0])))
            np.random.shuffle(folds_idx)
            split_idx = int(.8 * folds_idx.shape[0])
            folds_idx = [folds_idx[:split_idx], folds_idx[split_idx:]]
        else:
            folds = 1
            folds_idx = [np.array(list(range(np.unique(np.concatenate([np.unique(graph_list['TF']), np.unique(graph_list['TG'])])).shape[0])))] * 2

        # Run folds
        for fold in range(folds):
            print(f'Fold {fold:02d}')

            # Reproducibility
            fold_seed = 42+replication*folds+fold
            np.random.seed(fold_seed)
            torch.manual_seed(fold_seed)

            # Testing
            # graph_type.append(f'{contrast}_{group}_{cell_type}_{replication}_{fold}')
            # performance.append(np.random.rand())
            # continue

            # Convenience
            fold_idx = np.concatenate([idx for i, idx in enumerate(folds_idx) if i != fold])
            val_idx = folds_idx[fold]

            # Generate embeddings
            if features == 'embeddings':
                graph_list[['TF', 'TG', 'coex']].to_csv('_elist.edg', sep='\t', header=None, index=None)
                g = pp.pecanpy.SparseOTF(p=1, q=1, workers=8, verbose=False, random_state=fold_seed)
                g.read_edg('_elist.edg', weighted=True, directed=True)  # False if coexpression
                embeddings = g.embed(dim=dim, num_walks=10, walk_length=80, window_size=10, epochs=1)
                labels = np.array(g._node_ids)  # As long as no removing, `_node_idmap` isn't needed
            elif features == 'coex':
                labels = graph_matrix.index.to_numpy()
                embeddings_from = graph_matrix.loc[labels].to_numpy()
                embeddings_to = graph_matrix[labels].to_numpy().T
                embeddings = np.concatenate((embeddings_from, embeddings_to), axis=1)
                dim = embeddings.shape[1]

            # Annotate
            gene_list = gene_lists[group]
            annotation = np.array([g in gene_list for g in labels])

            # Predict relevancy
            X = torch.Tensor(embeddings)
            unique, inverse = np.unique(annotation, return_inverse=True)
            y = np.zeros((annotation.shape[0], unique.shape[0]))
            y[np.arange(y.shape[0]), inverse] = 1
            y = torch.Tensor(y)

            # Make model
            model = MLP(dim)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=gamma)

            # Train model
            model.train()
            lapses = 0; best_loss = np.inf
            for epoch in range(max_epochs):
                epoch_loss = []
                epoch_val_loss = []
                for batch in range(fold_idx.shape[0] // batch_size):
                    # Sample

                    # Training
                    batch_idx = np.random.choice(fold_idx, batch_size)
                    batch_X = X[batch_idx]
                    batch_y = y[batch_idx]
                    logits = model(batch_X)
                    loss = ((logits - batch_y)**2).mean()
                    epoch_loss.append(loss.detach())

                    # Validation
                    val_X = X[val_idx]
                    val_y = y[val_idx]
                    logits = model(val_X)
                    val_loss = ((logits - val_y)**2).mean()
                    epoch_val_loss.append(val_loss.detach())

                    # Iterate
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                epoch_loss = torch.Tensor(epoch_loss).mean()
                epoch_val_loss = torch.Tensor(epoch_val_loss).mean()

                # Early stopping
                if epoch_val_loss < best_loss:
                    best_loss = epoch_val_loss
                    lapses = 0
                else:
                    lapses += 1
                if epoch % (max_epochs // 5) == 0 or lapses >= max_lapses: print(f'Epoch: {epoch:03d}\tLoss: {epoch_loss:.4f}\tValidation Loss: {epoch_val_loss:.4f}')
                if lapses >= max_lapses:
                    break

            model.eval()

            ### Analysis
            # Processing
            logits = model(X[val_idx]).detach()
            probability = logits[:, 1]
            sorted_idx = torch.argsort(probability).flip(dims=(0,))
            true = annotation[val_idx][sorted_idx]

            # Outcomes
            if true.sum() != 0:
                average_positive_percentile = np.linspace(1, 0, num=val_idx.shape[0])[true].mean()
            else:
                average_positive_percentile = np.nan
            print(f'Average positive percentile of {average_positive_percentile:.3f}')
            if np.unique(true).shape[0] != 1:
                auprc = sklearn.metrics.average_precision_score(true, probability)
            else:
                auprc = np.nan
            print(f'AUPRC of {auprc:.3f}')
            if np.unique(true).shape[0] != 1:
                auroc = sklearn.metrics.roc_auc_score(true, probability)
            else:
                auroc = np.nan
            print(f'AUROC of {auroc:.3f}')

            ### Books
            graph_type.append(f'{contrast}_{group}_{cell_type}_{replication}_{fold}')
            app_list.append(average_positive_percentile)
            auprc_list.append(auprc)
            auroc_list.append(auroc)

            print()
        print()
    print()

### Format
graph_type = np.array(graph_type).reshape((-1, replications, folds))
app_list = np.array(app_list).reshape((-1, replications, folds))
auprc_list = np.array(auprc_list).reshape((-1, replications, folds))
auroc_list = np.array(auroc_list).reshape((-1, replications, folds))


c02x_AD_Astro_regulon_list.csv - AD - Astro
Replication 00
Fold 00
Epoch: 000	Loss: 0.0723	Validation Loss: 0.0801
Epoch: 029	Loss: 0.0392	Validation Loss: 0.0492
Average positive percentile of 0.480
AUPRC of 0.055
AUROC of 0.525

Fold 01
Epoch: 000	Loss: 0.0721	Validation Loss: 0.0718
Epoch: 028	Loss: 0.0361	Validation Loss: 0.0386
Average positive percentile of 0.568
AUPRC of 0.087
AUROC of 0.610

Fold 02
Epoch: 000	Loss: 0.0817	Validation Loss: 0.0966
Epoch: 029	Loss: 0.0321	Validation Loss: 0.0563
Average positive percentile of 0.478
AUPRC of 0.102
AUROC of 0.580

Fold 03
Epoch: 000	Loss: 0.0771	Validation Loss: 0.0732
Epoch: 029	Loss: 0.0435	Validation Loss: 0.0313
Average positive percentile of 0.581
AUPRC of 0.093
AUROC of 0.499

Fold 04
Epoch: 000	Loss: 0.0827	Validation Loss: 0.0710
Epoch: 029	Loss: 0.0398	Validation Loss: 0.0271
Average positive percentile of 0.399
AUPRC of 0.034
AUROC of 0.531


Replication 01
Fold 00
Epoch: 000	Loss: 0.0705	Validation Loss: 0.0775
Epoch: 02

In [None]:
# Format
m2 = re.compile('^([cr]\d+x(_\d+)?)_([a-zA-Z]+(_(resilient|strict))?)_((?!resilient|strict)\w+)_(\d+)_(\d+)$')
df_get_contrast = lambda s: m2.match(s).group(1)
df_get_group = lambda s: m2.match(s).group(3)
df_get_cell_type = lambda s: m2.match(s).group(6)
df_get_replication = lambda s: m2.match(s).group(7)
df_get_fold = lambda s: m2.match(s).group(8)
df = pd.DataFrame({
    'Contrast': np.vectorize(df_get_contrast)(graph_type.flatten()),
    'Group': np.vectorize(df_get_group)(graph_type.flatten()),
    'Cell Type': np.vectorize(df_get_cell_type)(graph_type.flatten()),
    'Replication': np.vectorize(df_get_replication)(graph_type.flatten()),
    'Fold': np.vectorize(df_get_fold)(graph_type.flatten()),
    'Average Positive Percentile': app_list.flatten(),
    'AUPRC': auprc_list.flatten(),
    'AUROC': auroc_list.flatten(),
})
df.to_csv('results.csv')

# Parameters
statistic = 'AUROC'
statistic_minimum = .5

# Collapse folds
df = df.groupby(['Contrast', 'Group', 'Cell Type', 'Replication']).mean().reset_index()

# Collapse all
df_summary = df.groupby(['Contrast', 'Group', 'Cell Type']).mean().reset_index()

# Filter
df_avg = df.groupby(['Contrast', 'Group', 'Cell Type']).mean().reset_index()
df_avg.index = df_avg.apply(lambda r: '_'.join([r['Contrast'], r['Group'], r['Cell Type']]), axis=1)
df.index = df.apply(lambda r: '_'.join([r['Contrast'], r['Group'], r['Cell Type']]), axis=1)
df = df.loc[df_avg.index[df_avg[statistic] > statistic_minimum].to_numpy()]
df = df.reset_index(drop=True)

# Sort
df = df.sort_values('Cell Type')
df_summary = df_summary.sort_values('Cell Type')

# Visualize boxplot
fig, ax = plt.subplots(1, 1, figsize=(18, 9))
sns.boxplot(df, x='Cell Type', y=statistic, hue='Group', ax=ax)
plt.axhline(.5, color='black', linestyle='--')
plt.xticks(rotation=90)
plt.tight_layout()
fig.savefig(f'./plots/grn_performance.pdf', format='pdf', transparent=True)
plt.close()

# Visualize line plot
# fig, ax = plt.subplots(1, 1, figsize=(9, 9))
# sns.lineplot(df_summary.sort_values('Cell Type'), x='Cell Type', y=statistic, hue='Group', ax=ax)
# plt.axhline(.5, color='black', linestyle='--')
# plt.xticks(rotation=90)
# plt.tight_layout()
# fig.savefig(f'./plots/grn_performance.pdf', format='pdf', transparent=True)
# plt.close()
