In [11]:
import numpy as np
import pandas as pd
# import networkx as nx
# import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
# import igviz as ig
# from node2vec import Node2Vec
# from gensim.models import KeyedVectors
# import seaborn as sns
# import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2


# import src.preprocess as pre
# import src.visualize as vis
# pio.renderers.default = "png"

from src import models, training
from torch_geometric.nn import GAE
import torch
import os
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, average_precision_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
# Dataset whose pipeline artefacts this notebook evaluates.
data_name = "seqfish"

# Each pipeline stage writes to its own numbered folder under ./out/<data_name>/.
_out_root = f"./out/{data_name}"
preprocess_output_path = f"{_out_root}/1_preprocessing_output/"
training_output_path = f"{_out_root}/2_training_output/"
evaluation_output_path = f"{_out_root}/3_evaluation_output/"

# Load the (train, test) pair of cell-level data objects plus the gene-level data
# from the preprocessing output folder.
# NOTE(review): with split=0.9 the captured output of this cell reports
# "0.9 testing edges" — confirm this train/test proportion is intended.
(cell_train_data, cell_test_data), genelevel_data = training.create_pyg_data(preprocess_output_path, split=0.9)
# `data` is read as a module-level global inside evaluate_model.
data = (cell_train_data, genelevel_data)



def build_clarifyGAE_pytorch(data, hyperparams = None):
    """Assemble the (untrained) multi-view graph auto-encoder.

    Args:
        data: Pair ``(cell_data, gene_data)``. Only ``.x.shape`` of each element
            is read here, to size the encoders.
        hyperparams: Mapping that must provide ``"concat_hidden_dim"`` (halved
            to get the per-encoder hidden size) and ``"num_genespercell"``.
            The ``None`` default is kept for signature compatibility only; a
            mapping is required.

    Returns:
        A ``GAE`` wrapping a ``models.MultiviewEncoder`` built from a
        ``models.GraphEncoder`` and a ``models.SubgraphEncoder``.

    Raises:
        TypeError: If ``hyperparams`` is ``None``.
    """
    # Fail early with an actionable message instead of the opaque
    # "'NoneType' object is not subscriptable" the lookups below would raise.
    if hyperparams is None:
        raise TypeError(
            "hyperparams is required: pass a mapping with "
            "'concat_hidden_dim' and 'num_genespercell'"
        )

    num_cells, num_cellfeatures = data[0].x.shape[0], data[0].x.shape[1]
    num_genefeatures = data[1].x.shape[1]

    # The two encoders each get half of the configured concatenated width.
    hidden_dim = hyperparams["concat_hidden_dim"] // 2
    num_genespercell = hyperparams["num_genespercell"]

    cellEncoder = models.GraphEncoder(num_cellfeatures, hidden_dim)
    geneEncoder = models.SubgraphEncoder(num_features=num_genefeatures, hidden_dim=hidden_dim, num_vertices = num_cells, num_subvertices = num_genespercell)

    multiviewEncoder = models.MultiviewEncoder(SubgraphEncoder = geneEncoder, GraphEncoder = cellEncoder)
    return GAE(multiviewEncoder)

def metrics(model, z, pos_edge_index, neg_edge_index):
    """Score link predictions for a set of positive and negative edges.

    Args:
        model: Object exposing ``decode(z, edge_index, sigmoid=True)``.
        z: Embedding tensor handed to ``model.decode``; also used to create
            the label tensors via ``new_ones`` / ``new_zeros``.
        pos_edge_index: Edge index whose columns are labelled 1.
        neg_edge_index: Edge index whose columns are labelled 0.

    Returns:
        Tuple ``(auprc, ap, auroc)``: area under the precision-recall curve,
        average precision, and ROC AUC.
    """
    # Ground truth: ones for every positive column, zeros for every negative one.
    labels = torch.cat(
        [z.new_ones(pos_edge_index.size(1)), z.new_zeros(neg_edge_index.size(1))],
        dim=0,
    )

    # Decoder scores, concatenated in the same positive-then-negative order.
    scores = torch.cat(
        [
            model.decode(z, pos_edge_index, sigmoid=True),
            model.decode(z, neg_edge_index, sigmoid=True),
        ],
        dim=0,
    )

    # scikit-learn works on NumPy arrays, so detach and move to host memory.
    labels_np = labels.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()

    precision, recall, _ = precision_recall_curve(labels_np, scores_np)
    auprc = auc(recall, precision)
    auroc = roc_auc_score(labels_np, scores_np)
    ap = average_precision_score(labels_np, scores_np)

    return auprc, ap, auroc

def evaluate_model(studyname, data_name):
    """Load a trained checkpoint and score it on the held-out cell graph.

    Relies on notebook globals: ``data``, ``cell_train_data``,
    ``cell_test_data`` and ``training_output_path``.

    Args:
        studyname: Prefix of the checkpoint file
            ``<studyname>_trained_gae_model.pth`` inside ``training_output_path``.
        data_name: Currently unused; the checkpoint folder comes from the
            global ``training_output_path``.
            NOTE(review): confirm whether the path should follow this argument.

    Returns:
        Tuple ``(auprc, ap, auroc)`` as produced by ``metrics``.
    """
    # Only "num_genespercell" and "concat_hidden_dim" are read by
    # build_clarifyGAE_pytorch; the remaining entries are kept as a record.
    hyperparameters = {
        "num_genespercell": 45,
        "concat_hidden_dim": 64,
        "optimizer" : "adam",
        "criterion" : torch.nn.BCELoss(),
        "num_epochs": 400
    }

    trained_gae = build_clarifyGAE_pytorch(data, hyperparameters)

    checkpoint_path = os.path.join(training_output_path, f'{studyname}_trained_gae_model.pth')
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; load_state_dict copies the values into the model's own tensors.
    trained_gae.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    trained_gae.eval()

    # Pure inference: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        z, _z_c, _z_g, _gene_embeddings = trained_gae.encode(
            cell_test_data.x, data[1].x, cell_test_data.edge_index, data[1].edge_index
        )
    print(cell_train_data, cell_test_data)

    # Split the labelled test edges into positives (label 1) and the rest.
    posmask = cell_test_data.edge_label == 1
    negmask = ~posmask

    return metrics(
        trained_gae,
        z,
        cell_test_data.edge_label_index[:, posmask],
        cell_test_data.edge_label_index[:, negmask],
    )
    
    

def training_metrics(studyname):
    """Print the column-wise maximum of the CLARIFY test metrics for a study.

    Reads ``<studyname>_metrics_0.3.csv`` from the global
    ``evaluation_output_path`` and prints the maximum of the AP, AUPRC and
    ROC columns. Returns ``None``.
    """
    csv_path = os.path.join(evaluation_output_path, f"{studyname}_metrics_0.3.csv")
    history = pd.read_csv(csv_path)
    reported_columns = ["CLARIFY Test AP","CLARIFY Test AUPRC", "CLARIFY Test ROC" ]
    best_scores = history[reported_columns].max(axis=0)
    print(best_scores)
    
    
# evaluate_model(studyname, data_name)


0.09999999999999998 training edges | 0.9 testing edges


In [44]:
def add_fake_edges(num_vertices, num_old_edges, fp, seed=None):
    """Sample random vertex pairs to serve as spurious edges.

    Args:
        num_vertices: Exclusive upper bound for sampled vertex ids; ids are
            drawn uniformly from ``[0, num_vertices)``.
        num_old_edges: Number of existing edges the rate is relative to.
        fp: Fraction of ``num_old_edges`` to generate; the number of sampled
            pairs is ``int(fp * num_old_edges)``.
        seed: Optional integer. When given, sampling uses a private
            ``np.random.RandomState(seed)`` so the result is reproducible.
            When ``None`` (default), NumPy's global generator is used.

    Returns:
        ``torch.int64`` tensor of shape ``(2, int(fp * num_old_edges))``.

    Note:
        Both endpoints are sampled independently, so self-loops, repeated
        pairs and pairs that coincide with existing edges are possible.
    """
    num_add_edges = int(fp * num_old_edges)

    # A private generator keeps seeded calls from disturbing global RNG state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Force int64 so the result always matches torch's index dtype, regardless
    # of the platform's default NumPy integer width.
    sampled = rng.randint(0, high=num_vertices, size=(2, num_add_edges), dtype=np.int64)
    add_edges = torch.from_numpy(sampled)

    print(f"Added {fp*100}% false edges")
    return add_edges
  


# Keep copies of the supervision labels / indices before they are overwritten below.
old_edge_label = cell_train_data.edge_label.clone()
old_edge_label_index = cell_train_data.edge_label_index.clone()

# Generate random edges amounting to 10% of the current positive-edge count and
# append them with label 1.
# NOTE(review): cell_train_data is modified in place, so re-running this cell
# appends another batch and the old_* copies then capture already-modified tensors.
posmask = cell_train_data.edge_label == 1
newedges = add_fake_edges(cell_train_data.x.size()[0], cell_train_data.edge_label_index[:, posmask].shape[1], 0.1)
cell_train_data.edge_label  = torch.cat([cell_train_data.edge_label , torch.ones(newedges.shape[1])])
cell_train_data.edge_label_index =  torch.cat([cell_train_data.edge_label_index , newedges],dim=1)



# Display shapes: index after injection, index before injection, labels after injection.
cell_train_data.edge_label_index.shape, old_edge_label_index.shape, cell_train_data.edge_label.shape


Added 10.0% false edges


(torch.Size([2, 821]), torch.Size([2, 782]), torch.Size([821]))

In [7]:
# studynames = ["withsplit0.3_withgenefeats_withpenalty_120epochs"]
# false_edge_combinations = [("0.0","0.1"),("0.1","0.0"),("0.1","0.1"),("0.2","0.2")]

# for fp,fn in false_edge_combinations:
#     studynames.append(f"fp{fp}_fn{fn}_withsplit0.3_withgenefeats_withpenalty_120epochs")
    
# Every study shares one suffix; noisy runs prepend equal fp/fn rates to it.
_base_study = "withsplit0.3_withgenefeats_withpenalty_120epochs"
studynames = [_base_study] + [
    f"fp{rate}_fn{rate}_{_base_study}" for rate in ("0.1", "0.2")
]

In [8]:
# One (AUPRC, AP, AUROC) tuple per study, in the order of `studynames`.
evals_thirty = [
    evaluate_model(studyname=study, data_name="seqfish")
    for study in studynames
]

Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[782], edge_label_index=[2, 782]) Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[7032], edge_label_index=[2, 7032])
Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[782], edge_label_index=[2, 782]) Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[7032], edge_label_index=[2, 7032])
Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[782], edge_label_index=[2, 782]) Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[7032], edge_label_index=[2, 7032])
Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[782], edge_label_index=[2, 782]) Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[7032], edge_label_index=[2, 7032])


In [9]:
# Display the (AUPRC, AP, AUROC) tuples collected per study.
# NOTE(review): the captured output below lists four tuples while the
# `studynames` list defined above has three entries — likely a stale run.
evals_thirty

[(0.7834087923170979, 0.7834949786878911, 0.7953464157351215),
 (0.7827574891767713, 0.7828389258965371, 0.7988982350794225),
 (0.759574005087805, 0.7596667198024474, 0.7715084528260862),
 (0.7829097595709589, 0.7830022986053731, 0.7988867889485544)]

In [20]:
# Display the (AUPRC, AP, AUROC) tuples collected per study.
# NOTE(review): same expression as an earlier cell but with different captured
# values, so it was executed against different kernel state — verify on a fresh run.
evals_thirty

[(0.9532126126915653, 0.9532425950953799, 0.9591026103973257),
 (0.9480925042124696, 0.9482083254422257, 0.9544046814756142),
 (0.8480183725962273, 0.8482407154209896, 0.8814499149669769)]

In [19]:
# Baseline study followed by runs with matching fp/fn noise rates.
_base_study = "withsplit0.3_withgenefeats_withpenalty_120epochs"
studynames = [_base_study] + [
    f"fp{rate}_fn{rate}_{_base_study}" for rate in ("0.1", "0.2", "0.3", "0.5")
]

# Print the maximum logged test metrics for the last study in the list.
training_metrics(studynames[-1])

CLARIFY Test AP       0.904813
CLARIFY Test AUPRC    0.904758
CLARIFY Test ROC      0.908586
dtype: float64


In [26]:
# Display the training data object's summary (tensor names and shapes).
cell_train_data

Data(x=[1597, 125], edge_index=[2, 782], y=[1597, 1597], edge_label=[782], edge_label_index=[2, 782])

In [6]:
# Dataset whose pipeline artefacts this notebook evaluates.
data_name = "seqfish"

# Each pipeline stage writes to its own numbered folder under ./out/<data_name>/.
_out_root = f"./out/{data_name}"
preprocess_output_path = f"{_out_root}/1_preprocessing_output/"
training_output_path = f"{_out_root}/2_training_output/"
evaluation_output_path = f"{_out_root}/3_evaluation_output/"

# Results of the false-edge robustness runs: CLARIFY metrics sit under the
# evaluation folder, the DeepLinc benchmark scores in their own tree.
metrics_path = f"{evaluation_output_path}false_edge_experiments/"
deeplinc_path = f"./benchmark/deeplinc/{data_name}/false_edge_experiments/"

In [13]:
# False-positive edge rates on the x-axis; 0.0 is the noise-free baseline.
fp_rates = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

# The baseline scores are hard-coded; scores for the noisy runs are read from disk.
deeplinc_aps = [0.8118579]
aps = [0.9088231]
for fp in fp_rates[1:]:
    metric_path = f"{metrics_path}fp{fp}_fn0.0_withsplit0.3_withgenefeats_withpenalty_120epochs_metrics_0.3.csv"

    # Take the last stored entry of each run's score series.
    deeplinc_scores = np.load(f"{deeplinc_path}fp{fp}fn0.0test_ap_scores_0.3.npy")
    deeplinc_aps.append(deeplinc_scores[-1])

    clarify_history = pd.read_csv(metric_path)
    aps.append(clarify_history["CLARIFY Test AP"].iloc[-1])

# One line per model, sharing the same x values.
fp_figure = go.Figure(
    data=[
        go.Scatter(
            x=fp_rates,
            y=aps,
            name="Clarify",
            opacity=0.7,
            marker=dict(color="#d14078", line=dict(width=2, color='black')),
        ),
        go.Scatter(
            x=fp_rates,
            y=deeplinc_aps,
            name="DeepLinc",
            opacity=0.7,
            marker=dict(color="#345c72", line=dict(width=2, color='black')),
        ),
    ]
)

# Both axes get the same boxed, outside-tick styling.
_axis_style = dict(
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
    ticks='outside',
    tickfont=dict(size=17, color='black'),
)
fp_figure.update_xaxes(**_axis_style)
fp_figure.update_yaxes(**_axis_style)

fp_figure.update_layout(
    title="Clarify and DeepLinc Average Precision over different FP edge rates",
    xaxis_title="False Positive Edge Rate",
    yaxis_title="Average Precision",
    legend_title="Model",
    width=800,
    height=600,
    boxmode='group',
    plot_bgcolor='white',
)
fp_figure.show()

In [18]:
# False-negative edge rates on the x-axis; 0.0 is the noise-free baseline.
fn_rates = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

# The baseline scores are hard-coded; scores for the noisy runs are read from disk.
fn_deeplinc_aps = [0.8118579]
fn_aps = [0.9088231]
for fn in fn_rates[1:]:
    metric_path = f"{metrics_path}fp0.0_fn{fn}_withsplit0.3_withgenefeats_withpenalty_120epochs_metrics_0.3.csv"

    # Take the last stored entry of each run's score series.
    deeplinc_scores = np.load(f"{deeplinc_path}fp0.0fn{fn}test_ap_scores_0.3.npy")
    fn_deeplinc_aps.append(deeplinc_scores[-1])

    clarify_history = pd.read_csv(metric_path)
    fn_aps.append(clarify_history["CLARIFY Test AP"].iloc[-1])

# One line per model, sharing the same x values.
fn_figure = go.Figure(
    data=[
        go.Scatter(
            x=fn_rates,
            y=fn_aps,
            name="Clarify",
            opacity=0.7,
            marker=dict(color="#d14078", line=dict(width=2, color='black')),
        ),
        go.Scatter(
            x=fn_rates,
            y=fn_deeplinc_aps,
            name="DeepLinc",
            opacity=0.7,
            marker=dict(color="#345c72", line=dict(width=2, color='black')),
        ),
    ]
)

# Both axes get the same boxed, outside-tick styling.
_axis_style = dict(
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
    ticks='outside',
    tickfont=dict(size=17, color='black'),
)
fn_figure.update_xaxes(**_axis_style)
fn_figure.update_yaxes(**_axis_style)

fn_figure.update_layout(
    title="Clarify and DeepLinc Average Precision over different FN edge rates",
    xaxis_title="False Negative Edge Rate",
    yaxis_title="Average Precision",
    legend_title="Model",
    width=800,
    height=600,
    boxmode='group',
    plot_bgcolor='white',
)
fn_figure.show()