In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
#load in training data on each potential synapse
data = pd.read_csv("./train_data.csv")

#load in additional features for each neuron
feature_weights = pd.read_csv("./feature_weights.csv")
morph_embeddings = pd.read_csv("./morph_embeddings.csv")

In [2]:
# join all feature_weight_i columns into a single np.array column
feature_weights["feature_weights"] = (
    feature_weights.filter(regex="feature_weight_")
    .sort_index(axis=1)
    .apply(lambda x: np.array(x), axis=1)
)
# delete the feature_weight_i columns
feature_weights.drop(
    feature_weights.filter(regex="feature_weight_").columns, axis=1, inplace=True
)

# join all morph_embed_i columns into a single np.array column
morph_embeddings["morph_embeddings"] = (
    morph_embeddings.filter(regex="morph_emb_")
    .sort_index(axis=1)
    .apply(lambda x: np.array(x), axis=1)
)
# delete the morph_embed_i columns
morph_embeddings.drop(
    morph_embeddings.filter(regex="morph_emb_").columns, axis=1, inplace=True
)

In [3]:
data = (
    data.merge(
        feature_weights.rename(columns=lambda x: "pre_" + x), 
        how="left", 
        validate="m:1",
        copy=False,
    )
    .merge(
        feature_weights.rename(columns=lambda x: "post_" + x),
        how="left",
        validate="m:1",
        copy=False,
    )
    .merge(
        morph_embeddings.rename(columns=lambda x: "pre_" + x),
        how="left",
        validate="m:1",
        copy=False,
    )
    .merge(
        morph_embeddings.rename(columns=lambda x: "post_" + x),
        how="left",
        validate="m:1",
        copy=False,
    )
)

In [4]:
#cosine similarity function
def row_feature_similarity(row):
    pre = row["pre_feature_weights"]
    post = row["post_feature_weights"]
    return (pre * post).sum() / (np.linalg.norm(pre) * np.linalg.norm(post))

In [5]:
# compute the cosine similarity between the pre- and post- feature weights
data["fw_similarity"] = data.apply(row_feature_similarity, axis=1)

In [6]:
# generate projection group as pre->post
data["projection_group"] = (
    data["pre_brain_area"].astype(str)
    + "->"
    + data["post_brain_area"].astype(str)
)

In [20]:
data.columns

Index(['ID', 'axonal_coor_x', 'axonal_coor_y', 'axonal_coor_z',
       'dendritic_coor_x', 'dendritic_coor_y', 'dendritic_coor_z', 'adp_dist',
       'post_skeletal_distance_to_soma', 'pre_skeletal_distance_to_soma',
       'pre_oracle', 'pre_test_score', 'pre_rf_x', 'pre_rf_y', 'post_oracle',
       'post_test_score', 'post_rf_x', 'post_rf_y', 'compartment',
       'pre_brain_area', 'post_brain_area', 'pre_nucleus_x', 'pre_nucleus_y',
       'pre_nucleus_z', 'post_nucleus_x', 'post_nucleus_y', 'post_nucleus_z',
       'pre_nucleus_id', 'post_nucleus_id', 'connected', 'pre_feature_weights',
       'post_feature_weights', 'pre_morph_embeddings', 'post_morph_embeddings',
       'fw_similarity', 'projection_group'],
      dtype='object')

In [7]:
data['fw_post_interaction'] = data['fw_similarity'] * data['post_test_score']
data['fw_pre_interaction'] = data['fw_similarity'] * data['pre_test_score']

In [22]:
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 185832 entries, 0 to 185831
Data columns (total 38 columns):
 #   Column                          Non-Null Count   Dtype  
---  ------                          --------------   -----  
 0   ID                              185832 non-null  int64  
 1   axonal_coor_x                   185832 non-null  int64  
 2   axonal_coor_y                   185832 non-null  int64  
 3   axonal_coor_z                   185832 non-null  int64  
 4   dendritic_coor_x                185832 non-null  int64  
 5   dendritic_coor_y                185832 non-null  int64  
 6   dendritic_coor_z                185832 non-null  int64  
 7   adp_dist                        185832 non-null  float64
 8   post_skeletal_distance_to_soma  185832 non-null  float64
 9   pre_skeletal_distance_to_soma   185832 non-null  float64
 10  pre_oracle                      185832 non-null  float64
 11  pre_test_score                  185832 non-null  float64
 12  pre_rf_x        

In [8]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from imblearn.over_sampling import SMOTE

from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

numeric_cols = ["fw_similarity", "adp_dist", "pre_oracle", "post_oracle", "fw_pre_interaction", "fw_post_interaction", "pre_test_score", "post_test_score", "post_skeletal_distance_to_soma", "pre_skeletal_distance_to_soma"]
cat_cols = ['projection_group', "compartment"]
all_cols = numeric_cols + cat_cols

preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numeric_cols),
        ('cat', OneHotEncoder(), cat_cols)
    ],
    remainder='passthrough' 
)


pipe = ImbPipeline([
    ('preprocessing', preprocessor), 
    ('sampling', SMOTE(random_state=2)),
    ('model', LogisticRegression(random_state=2, max_iter=300))
])

param_grid = {
    'model__C': [0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 5, 10],
    'model__penalty': ['l1', 'l2'],
    'model__solver': ['liblinear', 'saga'],
    'model__class_weight': [None, 'balanced']
}


train_data, test_data = train_test_split(data, test_size=0.2, random_state=1)

grid_search = GridSearchCV(pipe, param_grid, scoring='balanced_accuracy', cv=5, verbose=2, n_jobs=-1)

grid_search.fit(train_data[all_cols], train_data["connected"])


best_model = grid_search.best_estimator_

test_data['pred'] = best_model.predict(test_data[all_cols])

Fitting 5 folds for each of 64 candidates, totalling 320 fits
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   1.1s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   1.1s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=   1.2s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   1.2s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   1.2s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   1.1s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=   1.2s
[CV] END model__C=1e-05, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=   1.0s
[CV] END model__C=1e-05, model__class_wei



[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   0.9s




[CV] END model__C=1, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  54.3s
[CV] END model__C=1, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  54.2s
[CV] END model__C=1, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  54.8s




[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   1.1s




[CV] END model__C=1, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  54.4s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   1.2s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   1.3s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  28.9s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  29.5s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  31.1s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  30.0s
[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=liblinear; total time=   2.5s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  31.8s
[CV] END mod



[CV] END model__C=1, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  47.7s




[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  47.9s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   1.1s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.1s




[CV] END model__C=1, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  46.0s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.2s




[CV] END model__C=1, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  45.9s
[CV] END model__C=1, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  46.6s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.0s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   1.8s
[CV] END model__C=5, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.5s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   2.5s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   1.6s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   2.0s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   1.5s
[C



[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  43.9s




[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  44.2s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   0.8s
[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  43.9s




[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  44.1s




[CV] END model__C=5, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  44.1s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   0.9s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   1.3s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   0.9s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=liblinear; total time=   1.0s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=   2.4s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=   2.1s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=   2.5s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=   2.6s
[CV



[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   1.1s
[CV] END model__C=10, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=   2.5s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   0.7s




[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  45.2s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=liblinear; total time=   1.2s




[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  44.9s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  45.1s
[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  45.1s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.8s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.2s
[CV] END model__C=10, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   2.0s
[CV] END model__C=10, model__class_weight=balanced, model__penalty=l1, model__solver=liblinear; total time=   2.2s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   3.2s
[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=   2.6s
[CV] END model__C=1



[CV] END model__C=5, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  43.5s




[CV] END model__C=10, model__class_weight=None, model__penalty=l1, model__solver=saga; total time=  37.8s




[CV] END model__C=10, model__class_weight=None, model__penalty=l2, model__solver=saga; total time=  33.3s




[CV] END model__C=10, model__class_weight=balanced, model__penalty=l1, model__solver=saga; total time=  32.3s
[CV] END model__C=10, model__class_weight=balanced, model__penalty=l2, model__solver=saga; total time=  28.7s


In [9]:
best_model.steps[2][1], best_model.steps[2][1].penalty


(LogisticRegression(C=0.1, max_iter=300, random_state=2, solver='saga'), 'l2')

In [12]:
# Compute accuracy
accuracy = accuracy_score(test_data['connected'], test_data['pred'] > 0.5)
print(f"Accuracy: {accuracy}")

# Compute balanced accuracy
balanced_accuracy = balanced_accuracy_score(test_data['connected'], test_data['pred'] > 0.5)
print(f"Balanced Accuracy: {balanced_accuracy}")

# Display the confusion matrix
conf_matrix = confusion_matrix(test_data['connected'], test_data['pred'] > 0.5)
print(conf_matrix)


Accuracy: 0.7174644173594855
Balanced Accuracy: 0.780772993522433
[[26437 10459]
 [   42   229]]


In [26]:
# Compute accuracy
accuracy = accuracy_score(test_data['connected'], test_data['pred'] > 0.5)
print(f"Accuracy: {accuracy}")

# Compute balanced accuracy
balanced_accuracy = balanced_accuracy_score(test_data['connected'], test_data['pred'] > 0.5)
print(f"Balanced Accuracy: {balanced_accuracy}")

# Display the confusion matrix
conf_matrix = confusion_matrix(test_data['connected'], test_data['pred'] > 0.5)
print(conf_matrix)


Accuracy: 0.7174644173594855
Balanced Accuracy: 0.780772993522433
[[26437 10459]
 [   42   229]]


ADASYN:

Accuracy: 0.7106034923453601
Balanced Accuracy: 0.7773173343723897
[[26182 10714]
 [   42   229]]

smote

Accuracy: 0.7106034923453601
Balanced Accuracy: 0.7773173343723897
[[26182 10714]
 [   42   229]]

.7379

ACC WITH PROJ:
Accuracy: 0.6564156375279145
Balanced Accuracy: 0.7372041349695804
[[24175 12721]
 [   49   222]]