In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Read the three input tables in one pass:
#   data             - training data on each potential synapse
#   feature_weights  - additional per-neuron features
#   morph_embeddings - additional per-neuron features
data, feature_weights, morph_embeddings = (
    pd.read_csv(csv_path)
    for csv_path in (
        "./train_data.csv",
        "./feature_weights.csv",
        "./morph_embeddings.csv",
    )
)

In [2]:
def _pack_columns(frame, pattern, packed_name):
    """Collapse every column whose name matches ``pattern`` into one array column.

    Parameters
    ----------
    frame : pd.DataFrame
        Table holding one column per vector component.
    pattern : str
        Regular expression passed to ``DataFrame.filter`` to select the
        component columns.
    packed_name : str
        Name of the new column; each cell holds that row's components as a
        single ``np.ndarray``.

    Returns
    -------
    pd.DataFrame
        A new frame without the component columns and with ``packed_name``
        appended as the last column. ``frame`` itself is not modified. If no
        column matches ``pattern`` (for example because the cell is being
        re-run on an already packed frame) ``frame`` is returned unchanged.
    """
    component_cols = frame.filter(regex=pattern).columns
    if component_cols.empty:
        # Nothing left to pack: re-running the cell must not clobber the
        # packed column produced by the first run.
        return frame
    packed = (
        frame[component_cols]
        # Column labels are sorted as strings, so if the suffixes are integers
        # "_10" sorts before "_2". The order is identical for every row.
        .sort_index(axis=1)
        .apply(lambda row: np.array(row), axis=1)
    )
    return frame.drop(columns=component_cols).assign(**{packed_name: packed})


# join all feature_weight_i columns into a single np.array column
feature_weights = _pack_columns(feature_weights, "feature_weight_", "feature_weights")

# join all morph_emb_i columns into a single np.array column
morph_embeddings = _pack_columns(morph_embeddings, "morph_emb_", "morph_embeddings")

In [3]:
# Attach both per-neuron tables to every row twice: once with all of the
# table's columns prefixed "pre_" and once prefixed "post_". No join key is
# passed, so each merge matches on whichever column names the prefixed table
# shares with `data`.
for neuron_table in (feature_weights, morph_embeddings):
    for side in ("pre_", "post_"):
        data = data.merge(
            neuron_table.add_prefix(side),
            how="left",  # keep every row of `data`, matched or not
            validate="m:1",  # raise if the neuron table repeats a join key
            copy=False,
        )

In [4]:
# Print the column names, non-null counts and dtypes of the merged frame,
# which makes rows left unmatched by the left joins visible as missing values.
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 185832 entries, 0 to 185831
Data columns (total 34 columns):
 #   Column                          Non-Null Count   Dtype  
---  ------                          --------------   -----  
 0   ID                              185832 non-null  int64  
 1   axonal_coor_x                   185832 non-null  int64  
 2   axonal_coor_y                   185832 non-null  int64  
 3   axonal_coor_z                   185832 non-null  int64  
 4   dendritic_coor_x                185832 non-null  int64  
 5   dendritic_coor_y                185832 non-null  int64  
 6   dendritic_coor_z                185832 non-null  int64  
 7   adp_dist                        185832 non-null  float64
 8   post_skeletal_distance_to_soma  185832 non-null  float64
 9   pre_skeletal_distance_to_soma   185832 non-null  float64
 10  pre_oracle                      185832 non-null  float64
 11  pre_test_score                  185832 non-null  float64
 12  pre_rf_x        

In [5]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

In [6]:
def row_feature_similarity(row):
    """Cosine similarity between a row's pre- and post- feature-weight vectors.

    Parameters
    ----------
    row : mapping
        Anything indexable by ``"pre_feature_weights"`` and
        ``"post_feature_weights"`` (a DataFrame row, a dict, ...). Each entry
        is expected to be a 1-D vector.

    Returns
    -------
    float
        ``sum(pre * post) / (||pre|| * ||post||)``, or ``nan`` when the
        similarity is undefined: either entry is a scalar instead of a vector
        (e.g. the NaN left behind by an unmatched left merge), or either
        vector has zero norm.
    """
    pre = np.asarray(row["pre_feature_weights"], dtype=float)
    post = np.asarray(row["post_feature_weights"], dtype=float)

    # A missing entry arrives as a scalar NaN; two of them multiplied give a
    # plain Python float, which has no .sum() method.
    if pre.ndim == 0 or post.ndim == 0:
        return np.nan

    norm_product = np.linalg.norm(pre) * np.linalg.norm(post)
    if norm_product == 0:
        # 0/0 would emit a RuntimeWarning; report "undefined" explicitly.
        return np.nan

    return (pre * post).sum() / norm_product

In [7]:
# Score every row by the cosine similarity of its pre- and post- feature
# weights, then store the scores as a new column.
fw_similarity_by_row = data.apply(row_feature_similarity, axis=1)
data["fw_similarity"] = fw_similarity_by_row

In [8]:
# Label each row with its projection group, written as "<pre area>-><post area>".
pre_area = data["pre_brain_area"].astype(str)
post_area = data["post_brain_area"].astype(str)
data["projection_group"] = pre_area.str.cat(post_area, sep="->")

In [9]:
# Preview the first rows of the frame, which now carries the fw_similarity and
# projection_group columns computed in the cells above.
data.head()

Unnamed: 0,ID,axonal_coor_x,axonal_coor_y,axonal_coor_z,dendritic_coor_x,dendritic_coor_y,dendritic_coor_z,adp_dist,post_skeletal_distance_to_soma,pre_skeletal_distance_to_soma,...,post_nucleus_z,pre_nucleus_id,post_nucleus_id,connected,pre_feature_weights,post_feature_weights,pre_morph_embeddings,post_morph_embeddings,fw_similarity,projection_group
0,42593,1187660,411978,1089020,1187390,412220,1089160,304.185,353043.0,1182170.0,...,919560,557121,518848,False,"[0.40828925, 0.051097646, -0.02682111, 0.04239...","[-0.03917461, -0.1830603, -0.3704222, 0.183293...","[0.3733156323432922, 0.209817960858345, -0.123...","[1.0723994970321655, -0.7540942430496216, 0.11...",0.127256,RL->RL
1,42594,1204580,682542,873138,1204640,682870,873890,725.431,244156.0,914243.0,...,919560,557121,518848,False,"[0.40828925, 0.051097646, -0.02682111, 0.04239...","[-0.03917461, -0.1830603, -0.3704222, 0.183293...","[0.3733156323432922, 0.209817960858345, -0.123...","[1.0723994970321655, -0.7540942430496216, 0.11...",0.127256,RL->RL
2,42595,1191790,403683,1093180,1188590,402414,1092660,3423.03,363829.0,1171820.0,...,919560,557121,518848,False,"[0.40828925, 0.051097646, -0.02682111, 0.04239...","[-0.03917461, -0.1830603, -0.3704222, 0.183293...","[0.3733156323432922, 0.209817960858345, -0.123...","[1.0723994970321655, -0.7540942430496216, 0.11...",0.127256,RL->RL
3,42596,1184320,419286,1082930,1186620,419721,1085540,3442.39,344267.0,1192340.0,...,919560,557121,518848,False,"[0.40828925, 0.051097646, -0.02682111, 0.04239...","[-0.03917461, -0.1830603, -0.3704222, 0.183293...","[0.3733156323432922, 0.209817960858345, -0.123...","[1.0723994970321655, -0.7540942430496216, 0.11...",0.127256,RL->RL
4,42597,1189150,673302,944202,1188790,677771,942901,4442.38,313630.0,788566.0,...,919560,557121,518848,False,"[0.40828925, 0.051097646, -0.02682111, 0.04239...","[-0.03917461, -0.1830603, -0.3704222, 0.183293...","[0.3733156323432922, 0.209817960858345, -0.123...","[1.0723994970321655, -0.7540942430496216, 0.11...",0.127256,RL->RL


In [11]:
import torch

# Build the model inputs and labels once, then hold out a fixed share of the
# rows. Previously X_test / y_test were built from exactly the same rows as
# X_train / y_train, so any "test" metric was measured on the training data.
FEATURE_COLUMNS = ['fw_similarity', 'adp_dist']
TEST_FRACTION = 0.2  # share of rows reserved for evaluation
SPLIT_SEED = 0  # fixed so the split survives Restart & Run All

features = torch.tensor(data[FEATURE_COLUMNS].to_numpy(dtype="float32"))
labels = torch.tensor(data['connected'].to_numpy(dtype="float32")).view(-1, 1)

# Seeded shuffle of the row positions; the first n_test become the test set.
shuffled = torch.randperm(
    len(data), generator=torch.Generator().manual_seed(SPLIT_SEED)
)
n_test = int(len(data) * TEST_FRACTION)
test_idx, train_idx = shuffled[:n_test], shuffled[n_test:]

# NOTE(review): several rows can share one pre/post neuron pair and therefore
# the same fw_similarity value; a split grouped by neuron would be a stricter
# hold-out than this row-level one - confirm which is wanted.
X_train, y_train = features[train_idx], labels[train_idx]
X_test, y_test = features[test_idx], labels[test_idx]

In [12]:
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    """Three-layer fully connected binary classifier.

    Architecture: ``in_features -> hidden_sizes[0] -> hidden_sizes[1] -> 1``
    with ReLU after the first two layers and a sigmoid on the output.

    Parameters
    ----------
    in_features : int, default 2
        Number of input features per sample.
    hidden_sizes : pair of int, default (32, 16)
        Widths of the first and second hidden layers.

    The defaults reproduce the original fixed 2 -> 32 -> 16 -> 1 network, so
    ``SimpleNN()`` behaves exactly as before.
    """

    def __init__(self, in_features=2, hidden_sizes=(32, 16)):
        super().__init__()
        first_width, second_width = hidden_sizes
        self.fc1 = nn.Linear(in_features, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, 1)

    def forward(self, x):
        """Map a ``(batch, in_features)`` tensor to ``(batch, 1)`` sigmoid outputs."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # The sigmoid is applied here, so the output is already in (0, 1).
        x = torch.sigmoid(self.fc3(x))
        return x
