## Preparing perturbation node features
This part contains three main tasks. First, we will select the subset of the CMap gene expression signatures that were induced by CRISPR-Cas9 genetic modification only. We do this so that model training can proceed incrementally, focusing the learning on one genetic modification type at a time instead of generalizing across all types at once. Second, we will construct the features of the synthetic "therapeutic" node, which are the reversed disease signature. Finally, we will reduce the dimensionality of the node features using principal component analysis (PCA).

### 1. Selecting CRISPR-Cas9 induced signatures

In [1]:
import pandas as pd
import numpy as np

In [8]:
# Read the signature matrix, discard its first column, and order the rows by index label.
# (The dropped column is presumably a row index saved alongside the data — confirm against the CSV.)
df = pd.read_csv("20023_6056_signatures.csv").iloc[:, 1:].sort_index()
df

Unnamed: 0,100,10000,10001,10005,10006,10007,10010,100129250,100129482,100131755,...,9973,9976,9978,998,9980,9987,9988,9990,9991,9997
0,0.174898,0.170660,-0.327706,-1.012201,-0.709069,0.590302,0.317491,-0.336832,-0.622616,-1.192824,...,-0.020990,-0.943036,-0.457074,-0.451099,-0.275676,0.025439,-0.940749,0.108191,0.012258,0.436765
1,0.023692,-0.621700,-0.214531,0.067540,-0.066525,0.305235,0.673813,0.026184,-0.026370,0.469255,...,0.625724,0.280934,-0.312827,0.258809,0.337565,-0.288991,-0.516484,0.399671,-0.623557,-0.335016
2,0.406850,0.775148,-0.288365,1.054073,-0.470195,0.224927,-0.055932,-0.041696,-0.939581,0.170502,...,-0.822419,0.436913,-0.108121,0.189611,0.088665,-0.411371,0.159655,-0.312251,-0.004310,0.327478
3,0.012266,0.370053,0.905270,0.524906,1.090123,-1.512671,-0.746616,-0.560953,-0.119623,-0.714175,...,-0.903741,-0.134229,0.230682,0.673075,-0.450610,0.643537,-0.437256,-0.244052,-0.270511,-0.391445
4,0.680962,-1.441959,0.123246,-0.414082,-0.119506,-0.310105,0.768365,0.300217,0.037321,0.057293,...,-1.274699,0.608310,-0.498205,-1.115365,-0.178106,-0.262965,0.487434,-0.042183,1.015008,-0.002290
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20018,1.630630,-1.057249,-1.483222,-0.694036,0.392080,-0.302878,-0.181074,-1.345645,-0.120782,1.596910,...,0.896386,0.376598,0.010688,-0.353705,-0.522536,-0.718549,-1.132831,-0.174150,-0.610226,0.576494
20019,1.910709,-0.613249,-1.384393,1.606892,-0.516433,-0.464828,0.489844,0.257911,-0.760184,0.645760,...,1.832853,-0.387472,-0.745996,-0.018421,-0.794105,-0.904071,-0.857446,-0.791354,1.414317,-0.423272
20020,1.285126,0.607435,-1.041552,-0.743960,-0.831314,-0.496065,-0.702506,0.159211,0.202441,1.274807,...,1.931300,-0.839873,0.334764,-0.183769,-0.945705,-0.867503,-0.121041,0.739108,-0.742083,0.058644
20021,0.069787,-0.865963,-0.208207,-0.925451,-0.413606,0.946464,0.211409,-0.306070,-1.471482,-1.303913,...,-0.597694,-0.450268,-0.335929,0.760301,-0.961826,0.312094,-0.615488,-1.047882,0.030394,-0.519334


In [9]:
# Load the signature metadata table; its 'sig_id' column is used further down to filter df's rows.
xpr = pd.read_csv("xpr_uniq.csv")
xpr

Unnamed: 0,sig_id,pert_type,gene_id
0,XPR009_U251MG.311_96H:D02,trt_xpr,5972
1,XPR025_U251MG.311_96H:J07,trt_xpr,9496
2,XPR009_U251MG.311_96H:P11,trt_xpr,213
3,XPR010_U251MG.311_96H:K21,trt_xpr,2520
4,XPR008_U251MG.311_96H:J20,trt_xpr,3630
...,...,...,...
5522,XPR024_U251MG.311_96H:H23,trt_xpr,6594
5523,XPR024_U251MG.311_96H:J15,trt_xpr,6595
5524,XPR024_U251MG.311_96H:J03,trt_xpr,6598
5525,XPR024_U251MG.311_96H:H13,trt_xpr,6602


In [10]:
# Load the genperts mapping from its JSON file and invert it
import json

file_path = 'mapping_dicts/genperts_dict.json'
with open(file_path, 'r') as json_file:
    genperts_json = json.load(json_file)

# Invert the mapping: every JSON value becomes a key that points at its (string) JSON key.
# The inverted dict is what df's index labels are looked up in further down.
genperts_dict = dict(zip(genperts_json.values(), map(str, genperts_json.keys())))

print(len(genperts_dict))

20023


In [11]:
# Load the xpr_uniq mapping from its JSON file

file_path = 'mapping_dicts/xpr_uniq_dict.json'
with open(file_path, 'r') as json_file:
    xpr_uniq_json = json.load(json_file)

# Rebuild the mapping with every key passed through str(); the values are kept exactly as loaded.
xpr_uniq_dict = dict(zip(map(str, xpr_uniq_json.keys()), xpr_uniq_json.values()))

print(len(xpr_uniq_dict))

5527


In [12]:
# Step 1: relabel every row by looking its current index label up in genperts_dict.
df.index = df.index.map(genperts_dict)
# Step 2: keep only the rows whose new label occurs in xpr['sig_id'].
df = df[df.index.isin(xpr['sig_id'])]
# Step 3: relabel the surviving rows once more, this time through xpr_uniq_dict.
df.index = df.index.map(xpr_uniq_dict)
# NOTE(review): this cell rebinds df from itself, so it is not idempotent — re-running it
# without re-running the cell that loads df would look up already-mapped labels.
# Labels that are absent from a dict come back as NaN from Index.map — verify none are missing.
df

Unnamed: 0,100,10000,10001,10005,10006,10007,10010,100129250,100129482,100131755,...,9973,9976,9978,998,9980,9987,9988,9990,9991,9997
0,-0.068898,-1.385201,-0.350640,-0.507029,-0.866744,0.012788,-0.741100,0.018638,-0.483464,-0.501729,...,0.836564,0.130130,0.739138,0.704488,-0.059225,1.096602,-1.077828,0.807338,-0.491475,-0.043647
1,-0.837957,-0.424723,0.238129,-0.257697,-0.331083,0.621665,0.632254,-0.958893,-0.324983,-0.385073,...,0.208308,0.309714,0.599529,0.058085,-0.860963,0.691762,0.436001,-0.155574,0.219674,0.053126
2,0.744123,0.122277,0.891896,-0.328410,0.357296,-0.883692,-0.578787,-0.182959,-1.051222,0.311895,...,0.635955,-0.680460,-0.371369,0.094702,0.141131,0.471570,-0.101115,-1.179156,0.675766,-0.053197
3,-0.056949,-0.566240,0.222091,0.129431,-0.764395,-0.487019,-0.660915,0.181364,-0.112409,0.503623,...,-0.224525,0.394164,0.495686,0.642808,0.237943,0.580772,0.416120,-0.410499,0.301205,-0.858144
4,0.046869,-1.095164,0.821904,-0.432146,0.046760,-0.191532,1.283893,0.095794,-1.245688,-1.169470,...,0.661266,0.255379,0.094150,0.155992,-0.693762,-0.001730,-0.427046,0.465716,0.302346,-0.667659
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5522,0.232098,0.504555,0.394242,0.405648,-0.017280,0.113865,-0.468600,-0.466273,0.801038,0.570780,...,1.020092,0.120485,0.264322,-0.163068,0.244077,0.664830,0.937642,-0.366926,0.677980,-0.212613
5523,-0.531510,-0.528307,1.200992,0.303188,-0.296790,0.521208,-0.188038,-1.280407,-0.214688,-0.860647,...,-0.190510,-0.601682,-0.537331,0.733204,-0.483163,0.349452,-0.462393,-0.041428,1.041209,-0.981610
5524,1.474548,0.788736,0.455317,-0.894020,1.268071,-1.083572,0.155847,1.660888,0.554500,2.460362,...,2.132888,1.720682,1.610572,-0.559167,1.990002,-2.424774,-0.842418,1.526391,-0.627741,0.533655
5525,-0.589825,-0.105034,-0.655202,-0.244863,-1.090093,0.256787,-1.421959,0.117661,-0.302168,0.387863,...,-0.374813,0.184845,1.352266,-0.038880,0.601953,-0.564261,-0.743027,-1.255958,-1.276335,-0.020384


### 2. Therapeutic node features
We will now prepare the node features for the synthetic "therapeutic" node. In this part, the disease signature, that was obtained from the analysis of the disease cohort, will be reversed to represent a complete inversion of the disease state. The synthetic node features will then be added to the node feature matrix before PCA is applied.

In [14]:
# Load the disease signature, discard its first column, and cast Gene.ID to str
# (presumably so the IDs compare equal to df's string column labels — confirm).
AD_sig = (
    pd.read_csv("t_based_signature.csv")
    .iloc[:, 1:]
    .assign(**{'Gene.ID': lambda frame: frame['Gene.ID'].astype('str')})
)
# Keep only the genes that also appear as columns of the node feature matrix df
AD_sig = AD_sig.loc[AD_sig['Gene.ID'].isin(df.columns)]
AD_sig

Unnamed: 0,Gene.ID,t
0,100,-1.563326
2,10000,0.821559
3,10001,2.538692
6,10005,2.167191
7,10006,1.243114
...,...,...
9941,9987,2.016958
9942,9988,2.626368
9945,9990,1.838147
9946,9991,-0.992873


In [17]:
# Take the 't' column of the filtered disease signature as a Series (it keeps AD_sig's row labels).
t_scores = AD_sig['t']

In [18]:
# Define the signature reversal function
def reverse_signature(signature):
    """Return a sign-flipped copy of a signature.

    Parameters
    ----------
    signature : pandas.Series or pandas.DataFrame
        Numeric scores to reverse. Any object that supports unary minus
        (e.g. a NumPy array) is accepted as well.

    Returns
    -------
    Same type as ``signature``
        A new object with every value negated; index/labels and name are
        preserved. The input is not modified.
    """
    # Vectorised negation: same values as applying `lambda x: -x` to each
    # element, without one Python-level call per element.
    reversed_signature = -signature

    return reversed_signature

# Negate the disease t-scores to obtain the reversed signature.
reversed_signature = reverse_signature(t_scores)
reversed_signature

0       1.563326
2      -0.821559
3      -2.538692
6      -2.167191
7      -1.243114
          ...   
9941   -2.016958
9942   -2.626368
9945   -1.838147
9946    0.992873
9948   -4.377658
Name: t, Length: 6056, dtype: float64

In [19]:
# Wrap the reversed signature in a DataFrame so it can be transposed into a single row below.
# NOTE(review): this rebinds the same name to a different kind of object than the cell above produced.
reversed_signature = pd.DataFrame(reversed_signature)
reversed_signature

Unnamed: 0,t
0,1.563326
2,-0.821559
3,-2.538692
6,-2.167191
7,-1.243114
...,...
9941,-2.016958
9942,-2.626368
9945,-1.838147
9946,0.992873


In [20]:
# Build the single-row feature frame of the synthetic "therapeutic" node.
# Each reversed score is labelled with its own gene ID and then aligned to df's
# columns BY LABEL. Assigning df.columns directly would only be correct while
# AD_sig and df happen to list the genes in exactly the same order.
gene_ids = AD_sig.loc[reversed_signature.index, 'Gene.ID'].to_numpy()
snythetic_node_signature = pd.DataFrame(reversed_signature.T)
snythetic_node_signature.columns = gene_ids
snythetic_node_signature = snythetic_node_signature.reindex(columns=df.columns)

# Fail loudly here instead of passing missing values on to the later steps.
genes_without_score = snythetic_node_signature.columns[snythetic_node_signature.isna().any()]
if len(genes_without_score) > 0:
    raise ValueError(
        f"{len(genes_without_score)} feature column(s) have no reversed disease score, "
        f"e.g. {list(genes_without_score[:5])}"
    )
snythetic_node_signature

Unnamed: 0,100,10000,10001,10005,10006,10007,10010,100129250,100129482,100131755,...,9973,9976,9978,998,9980,9987,9988,9990,9991,9997
t,1.563326,-0.821559,-2.538692,-2.167191,-1.243114,-5.251242,-1.0704,-2.641217,0.121646,-2.453519,...,-2.523162,2.175655,-6.919617,-2.615592,-1.318495,-2.016958,-2.626368,-1.838147,0.992873,-4.377658


In [21]:
# Append the synthetic node as one extra row below the existing rows of df.
# ignore_index=True renumbers all rows 0..n-1, so the row labels assigned to df earlier are discarded here.
# NOTE(review): df is rebound from itself — re-running this cell appends the synthetic row a second time.
df = pd.concat([df,snythetic_node_signature], axis=0, ignore_index=True)
df

Unnamed: 0,100,10000,10001,10005,10006,10007,10010,100129250,100129482,100131755,...,9973,9976,9978,998,9980,9987,9988,9990,9991,9997
0,-0.068898,-1.385201,-0.350640,-0.507029,-0.866744,0.012788,-0.741100,0.018638,-0.483464,-0.501729,...,0.836564,0.130130,0.739138,0.704488,-0.059225,1.096602,-1.077828,0.807338,-0.491475,-0.043647
1,-0.837957,-0.424723,0.238129,-0.257697,-0.331083,0.621665,0.632254,-0.958893,-0.324983,-0.385073,...,0.208308,0.309714,0.599529,0.058085,-0.860963,0.691762,0.436001,-0.155574,0.219674,0.053126
2,0.744123,0.122277,0.891896,-0.328410,0.357296,-0.883692,-0.578787,-0.182959,-1.051222,0.311895,...,0.635955,-0.680460,-0.371369,0.094702,0.141131,0.471570,-0.101115,-1.179156,0.675766,-0.053197
3,-0.056949,-0.566240,0.222091,0.129431,-0.764395,-0.487019,-0.660915,0.181364,-0.112409,0.503623,...,-0.224525,0.394164,0.495686,0.642808,0.237943,0.580772,0.416120,-0.410499,0.301205,-0.858144
4,0.046869,-1.095164,0.821904,-0.432146,0.046760,-0.191532,1.283893,0.095794,-1.245688,-1.169470,...,0.661266,0.255379,0.094150,0.155992,-0.693762,-0.001730,-0.427046,0.465716,0.302346,-0.667659
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5523,-0.531510,-0.528307,1.200992,0.303188,-0.296790,0.521208,-0.188038,-1.280407,-0.214688,-0.860647,...,-0.190510,-0.601682,-0.537331,0.733204,-0.483163,0.349452,-0.462393,-0.041428,1.041209,-0.981610
5524,1.474548,0.788736,0.455317,-0.894020,1.268071,-1.083572,0.155847,1.660888,0.554500,2.460362,...,2.132888,1.720682,1.610572,-0.559167,1.990002,-2.424774,-0.842418,1.526391,-0.627741,0.533655
5525,-0.589825,-0.105034,-0.655202,-0.244863,-1.090093,0.256787,-1.421959,0.117661,-0.302168,0.387863,...,-0.374813,0.184845,1.352266,-0.038880,0.601953,-0.564261,-0.743027,-1.255958,-1.276335,-0.020384
5526,0.287277,0.403055,-1.104795,-0.256377,-1.121449,0.122438,-0.672024,-0.001600,-0.781647,-0.679878,...,0.434054,-1.230185,-0.683082,-0.395456,-0.938981,0.696579,-0.689716,-0.458554,-0.279167,-0.675389


### 3. Dimensionality reduction
Finally, we will use PCA to reduce the high-dimensional space of the gene expression signatures (6,056 genes) in order to avoid over-parameterizing the model. The first 512 principal components of each signature will be the final node features. This number of components is considered a good trade-off between the number of model parameters and the explained variance it preserves (~0.95).

In [22]:
from sklearn.decomposition import PCA

N_COMPONENTS = 512  # number of principal components kept as node features
PCA_SEED = 0        # fixed seed so the projection is repeatable across runs

# With the default svd_solver='auto', scikit-learn may choose the randomized SVD
# solver for a matrix of this size and component count; without a fixed
# random_state the saved node features would then differ from run to run.
pca = PCA(n_components=N_COMPONENTS, random_state=PCA_SEED)
reduced_data = pca.fit_transform(df)

# Get the explained variance ratio for each selected component
explained_variance = pca.explained_variance_ratio_

# Calculate the total variance explained
total_variance_explained = explained_variance.sum()

print("Explained variance ratio for each component:", explained_variance)
print("Total variance explained by the selected components:", total_variance_explained)


Explained variance ratio for each component: [0.144315   0.10548731 0.0251943  0.02223336 0.02041975 0.01662845
 0.01424112 0.01139389 0.0090213  0.00830117 0.00785139 0.00759166
 0.00737512 0.00706748 0.00648701 0.00624338 0.00597462 0.00595453
 0.00562137 0.0055211  0.00526466 0.00517896 0.00499365 0.0049118
 0.00477402 0.00468096 0.00455771 0.00436442 0.00430517 0.00422715
 0.00417039 0.0040164  0.00399373 0.00388883 0.00382069 0.00370933
 0.00368876 0.00361245 0.00356384 0.00345617 0.00342189 0.00340201
 0.00333784 0.00328756 0.00323353 0.00317122 0.00315221 0.00311202
 0.00305234 0.00297641 0.00297021 0.00293523 0.00290255 0.00285692
 0.00282395 0.00277376 0.00274108 0.00267989 0.00267567 0.00262628
 0.00262049 0.00258739 0.00257706 0.00253963 0.00248747 0.00246509
 0.00245382 0.00239997 0.00238262 0.00235289 0.00233682 0.00230127
 0.00226411 0.00224541 0.00223782 0.00219381 0.00218937 0.00217013
 0.00215608 0.00214264 0.00212144 0.00208685 0.00207557 0.00205608
 0.00204095 0.0020

In [23]:
# Sanity check: (number of rows, number of components) of the PCA-projected matrix.
print(reduced_data.shape)

(5528, 512)


In [24]:
# Wrap the projected array in a DataFrame for display; rows and columns get default integer labels.
reduced_data_df = pd.DataFrame(reduced_data)
reduced_data_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
0,-11.605774,-11.112385,4.197629,0.694979,-0.679183,0.000990,1.684370,1.867259,-6.410202,0.840989,...,0.921559,1.766225,-0.635588,-0.122904,0.706343,0.857873,-0.908231,-0.695669,0.191792,-0.483671
1,-0.480840,-12.886077,1.156119,-9.369387,-12.271465,-5.299552,1.814872,1.231757,-0.795256,7.244125,...,-0.682732,0.276837,1.003047,-1.080344,-1.102473,-1.416938,0.402696,-2.642268,0.586465,-0.731499
2,1.981056,-12.529287,2.230314,-8.712580,1.474912,-0.346716,-3.148813,3.300624,3.352898,1.445657,...,-0.608449,-0.640017,0.413258,-1.368983,0.223713,0.477711,0.425333,-0.186627,0.062133,0.696777
3,11.394960,-6.205779,5.738248,-5.676596,-1.547601,2.967252,1.724745,-5.372466,4.032555,3.903292,...,-1.209299,-0.211278,-0.831071,-0.151926,-1.306441,0.577610,0.594553,1.175333,-0.510210,-0.404997
4,-12.668876,-17.909757,15.574519,-11.707729,1.025009,-10.637166,9.452777,-5.774369,5.954382,1.701811,...,-1.667598,0.556815,1.045176,0.056079,-0.584343,0.085227,0.344022,-0.044363,-0.629879,-1.281119
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5523,-23.093196,-23.525455,-8.188239,-1.683831,-1.509904,-0.099511,0.116826,-5.021516,-0.645964,-3.190032,...,1.016368,-0.221268,-0.388789,1.241723,1.284948,-0.589265,0.097013,-0.453488,-0.352934,0.566398
5524,73.848267,11.103146,37.750644,-10.529579,17.920021,-7.318146,-0.912092,-15.001574,0.839160,-1.381615,...,0.956501,-1.153842,-0.835331,1.488341,1.849079,0.931293,0.398589,1.377784,-0.246252,-0.687388
5525,16.116392,18.346034,-7.300534,8.619267,-19.476395,-6.250095,-9.509305,2.937957,-9.139140,5.399104,...,0.923417,0.685783,-0.358618,0.478194,-0.609039,-0.947595,-1.236446,-1.091750,-0.887053,0.218692
5526,-40.867333,2.810133,8.345659,2.838778,-2.348983,-5.479776,4.596636,7.087862,-5.183913,4.669562,...,-0.229968,0.025007,-0.285935,-0.434765,-0.128316,-0.777172,0.636396,0.609945,0.759864,0.232392


In [25]:
# Persist the PCA node features as a .npy file. Only the numeric array is written (no row labels);
# rows follow df's row order, in which the synthetic node was appended last.
np.save("perts_pca_node_features.npy", reduced_data)