In [1]:
import scanpy as sc

from warnings import filterwarnings
filterwarnings('ignore')

from sklearn.model_selection import GridSearchCV

from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
import pandas as pd

from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score
import matplotlib.pyplot as plt

import pickle
import numpy as np

from scipy.stats import ks_2samp
import seaborn as sns

In [2]:
# Read the AnnData object from the working directory.
adata = sc.read_h5ad('sc_training.h5ad')
# Keep a copy of X exactly as loaded under its own layer; the next cell builds
# its DataFrame from this layer.
# NOTE(review): assumes the file's X already holds normalized log counts, as
# the layer name suggests — confirm against the data source.
adata.layers['normalized_logcounts'] = adata.X.copy()
# Replace X with a copy of the stored 'rawcounts' layer so the scanpy calls
# below operate on it.
adata.X = adata.layers['rawcounts'].copy()  

# scanpy total-count normalization (target_sum=5e3) followed by log1p. The
# return values are not assigned, so both calls are relied on to update adata
# in place.
sc.pp.normalize_total(adata, target_sum=5e3)
sc.pp.log1p(adata)

In [3]:
# DataFrame built from the 'normalized_logcounts' layer (the copy of X taken at
# load time), not from the re-normalized adata.X.
normdf = adata.to_df(layer="normalized_logcounts")

In [4]:
# One-column frame holding the per-cell 'condition' annotation.
df_cond = adata.obs['condition'].to_frame()
is_unperturbed = df_cond['condition'] == 'Unperturbed'
unpert_sample_in = df_cond.loc[is_unperturbed].index

# Expression rows restricted to the unperturbed cells.
df_unpert = normdf.filter(items=unpert_sample_in, axis=0)

In [5]:
def filter(normdf, var_divisor=100):
    """Keep only the columns whose variance is close enough to the largest one.

    Note: this name shadows Python's built-in ``filter`` for the rest of the
    notebook; it is kept because later cells call it by this name.

    Parameters
    ----------
    normdf : pandas.DataFrame
        Numeric table; the variance is computed per column (``axis=0``).
    var_divisor : float, default 100
        A column is kept when its variance is at least
        ``(largest column variance) / var_divisor``.

    Returns
    -------
    pandas.DataFrame
        ``normdf`` restricted to the retained columns, in their original order.
    """
    var_per_gene = normdf.var(axis=0)
    # .max() replaces the former ``sort_values()[-1:][0]`` lookup. That lookup
    # relied on deprecated integer-as-position Series indexing, and because
    # NaNs sort last it returned NaN (dropping every column) as soon as one
    # column had an undefined variance. .max() skips NaN.
    threshold = var_per_gene.max() / var_divisor
    kept_columns = var_per_gene[var_per_gene >= threshold].index
    return normdf.filter(kept_columns.values, axis=1)

In [6]:
# Preview the first rows of the variance-filtered unperturbed expression table.
filter(df_unpert).head()

Unnamed: 0,Mrpl15,Lypla1,Tcea1,Atp6v1h,Rb1cc1,Pcmtd1,Rrs1,Vcpip1,Snhg6,Cops5,...,mt-Nd3,mt-Nd4l,mt-Nd4,mt-Nd5,mt-Nd6,mt-Cytb,CAAA01118383.1,Vamp7,CAAA01147332.1,AC149090.1
053l1_AAACCTGAGATGTCGG-1,0.51152,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.610826,1.099769,0.0,0.0,1.73603,0.0,0.0,0.0,0.0
053l1_AAACCTGAGTGTTAGA-1,0.0,1.089103,0.686024,0.686024,0.0,0.0,0.0,0.0,0.0,0.0,...,0.686024,1.375591,0.686024,0.0,0.686024,1.779859,0.686024,0.0,0.0,0.0
053l1_AAACCTGCATAGACTC-1,0.292065,0.292065,0.292065,0.292065,0.292065,0.0,0.292065,0.0,0.0,0.517829,...,0.517829,2.135607,1.216197,1.479762,0.0,2.284313,0.292065,0.0,0.292065,0.0
053l1_AAACGGGAGTGGAGAA-1,1.020401,1.020401,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.020401,0.0,0.0,1.844182,0.0,0.0,1.020401,0.0
053l1_AAACGGGCAATCGAAA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.836899,2.187799,2.187799,0.0,0.0,2.187799,0.0,0.0,0.0,0.0


In [7]:
# Genes of interest: every condition label present in the data, followed by
# the extra perturbation genes listed in val_pert.
pert = adata.obs['condition'].values.unique().to_list()
val_pert = [
    'Aqr', 'Bach2', 'Bhlhe40', 'Ets1',
    'Fosb', 'Mafk', 'Stat3',
]
all_pert = [*pert, *val_pert]

In [8]:
def cond_filter(normdf, condition):
    """Return the rows of ``normdf`` for cells annotated with ``condition``.

    The per-cell annotation is read from the notebook-level
    ``adata.obs['condition']``; rows of ``normdf`` are matched by index label.
    """
    condition_labels = adata.obs['condition'].to_frame()
    is_match = condition_labels['condition'] == condition
    matching_cells = condition_labels.loc[is_match].index
    return normdf.filter(matching_cells, axis=0)

In [9]:
def real_perturbseq(df, gene_of_interest, df_svm_trained,
                    model_path='../saved_models/svc_model_unperturbed.sav'):
    """Score the saved classifier on the cells perturbed for one gene.

    Cells are selected through the notebook-level ``adata.obs['condition']``
    and their true labels are read from ``adata.obs['state']``.

    Parameters
    ----------
    df : pandas.DataFrame
        Expression table whose index labels match ``adata.obs``.
    gene_of_interest : str
        Condition label to select; it must also be a column of ``df``.
    df_svm_trained : pandas.DataFrame
        Only its ``columns`` are used, to select and order the model features.
    model_path : str, optional
        Pickled classifier to load. Defaults to the path this notebook has
        always used.

    Returns
    -------
    list
        The first five diagonal entries of the confusion matrix normalized
        over all entries. Empty when the gene is not a column of ``df``, when
        fewer than 10 cells match, or when the matrix has fewer than 5 classes.
    """
    condition_labels = adata.obs['condition']
    goi_sample_index = condition_labels[condition_labels == gene_of_interest].index
    # Filter once and reuse the result for both the size check and the scoring.
    pert_goi_df = df.filter(goi_sample_index, axis=0)

    if gene_of_interest not in df.columns or pert_goi_df.shape[0] < 10:
        print('gene of interest not found in columns, or less than 10 samples')
        return []

    features = pert_goi_df[df_svm_trained.columns]
    state_labels = adata.obs['state'].filter(features.index.values)

    # NOTE: unpickling can execute arbitrary code; only load trusted model files.
    # The context manager closes the file handle, which was previously leaked.
    with open(model_path, 'rb') as model_file:
        loaded_model = pickle.load(model_file)

    result = loaded_model.score(features, state_labels)
    # The literal 'rb' suffix is kept so the printed line matches earlier runs.
    print(result, 'rb')
    preds = loaded_model.predict(features)
    cm = confusion_matrix(state_labels, preds, normalize='all')
    print(loaded_model.classes_)
    print(f1_score(state_labels, preds, average=None))

    # Without an explicit label list the matrix only covers classes that occur
    # in the labels or predictions, so a smaller matrix cannot be mapped onto
    # the five expected states.
    if len(cm) < 5:
        print('less than 5 categories. not enough data')
        return []

    dist = [cm[i][i] for i in range(5)]
    print(dist)
    return dist
    

In [10]:
def fake_perturbseq(gene_of_interest, X_test, y_test,
                    model_path='../saved_models/svc_model_unperturbed.sav'):
    """Score the saved classifier on cells with zero expression of one gene.

    Rows of ``X_test`` whose ``gene_of_interest`` value equals 0 are used as a
    stand-in for a perturbation of that gene.

    Parameters
    ----------
    gene_of_interest : str
        Column of ``X_test`` to test for zero expression.
    X_test : pandas.DataFrame
        Feature table passed to the classifier.
    y_test : pandas.Series
        True labels, matched to ``X_test`` by index label.
    model_path : str, optional
        Pickled classifier to load. Defaults to the path this notebook has
        always used.

    Returns
    -------
    list or numpy.ndarray
        The first five diagonal entries of the confusion matrix normalized
        over all entries. An empty list is returned when the gene is not a
        column or no row has zero expression.
        NOTE(review): when the matrix has fewer than 5 classes the whole
        matrix is returned (original behaviour, kept for compatibility);
        ``real_perturbseq`` returns ``[]`` in that case.
    """
    print(gene_of_interest)

    if gene_of_interest not in X_test.columns:
        print('gene of interest not found in columns')
        return []

    # Select the zero-expression rows once and reuse them below.
    goi_df = X_test[X_test[gene_of_interest] == 0]
    if goi_df.shape[0] == 0:
        print('gene of interest not found in columns')
        return []

    goi_labels = y_test.filter(goi_df.index.values, axis=0)

    # NOTE: unpickling can execute arbitrary code; only load trusted model files.
    # The context manager closes the file handle, which was previously leaked.
    with open(model_path, 'rb') as model_file:
        loaded_model = pickle.load(model_file)

    result = loaded_model.score(goi_df, goi_labels)
    # The literal 'rb' suffix is kept so the printed line matches earlier runs.
    print(result, 'rb')
    preds = loaded_model.predict(goi_df)
    cm = confusion_matrix(goi_labels, preds, normalize='all')
    print(loaded_model.classes_)
    print(f1_score(goi_labels, preds, average=None))

    if len(cm) < 5:
        return cm

    dist = [cm[i][i] for i in range(5)]
    print(dist)
    return dist


In [11]:
# Variance-filtered unperturbed expression table, and the 'state' annotation of
# the same cells (matched by index label).
df = filter(df_unpert)
labels = adata.obs['state'].filter(df.index.values, axis = 0)

In [14]:
# Score the saved classifier on unperturbed cells whose 'Aqr' expression is 0.
Aqr = fake_perturbseq('Aqr', filter(df_unpert), labels)

Aqr
0.9896542335965152 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99646365 0.98466594 0.96124031 0.98728814 0.98820556]
[0.3452218894636537, 0.24475905254560304, 0.016879934658317452, 0.06343588347399945, 0.3193574734549415]


In [None]:
# Reference: val_pert = ['Aqr', 'Bach2', 'Bhlhe40', 'Ets1', 'Fosb', 'Mafk', 'Stat3'] (positions 0-6 are used by the cells below)

In [15]:
# val_pert[1] is 'Bach2': score the classifier on unperturbed cells with zero Bach2 expression.
Bach2 = fake_perturbseq(val_pert[1], filter(df_unpert), labels)

Bach2
0.9901055408970977 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99665552 0.98335068 0.96153846 0.98684211 0.98856759]
[0.39313984168865435, 0.20778364116094986, 0.016490765171503958, 0.04947229551451187, 0.3232189973614776]


In [16]:
# val_pert[2] is 'Bhlhe40': score the classifier on unperturbed cells with zero Bhlhe40 expression.
Bhlhe40 = fake_perturbseq(val_pert[2], filter(df_unpert), labels)

Bhlhe40
0.9875222816399287 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99253731 0.98591549 0.98591549 0.97297297 0.98870056]
[0.23707664884135474, 0.31194295900178254, 0.062388591800356503, 0.06417112299465241, 0.31194295900178254]


In [17]:
# val_pert[3] is 'Ets1': score the classifier on unperturbed cells with zero Ets1 expression.
Ets1 = fake_perturbseq(val_pert[3], filter(df_unpert), labels)

Ets1
0.9890909090909091 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99502488 0.98658718 0.96969697 0.98924731 0.98813056]
[0.2727272727272727, 0.3009090909090909, 0.02909090909090909, 0.08363636363636363, 0.30272727272727273]


In [18]:
# val_pert[4] is 'Fosb': score the classifier on unperturbed cells with zero Fosb expression.
Fosb = fake_perturbseq(val_pert[4], filter(df_unpert), labels)

Fosb
0.9900439916647372 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99622879 0.9847769  0.96052632 0.98905109 0.98842511]
[0.3669830979393378, 0.2171799027552674, 0.016902060662190323, 0.06274600601991202, 0.32623292428802964]


In [19]:
# val_pert[5] is 'Mafk': score the classifier on unperturbed cells with zero Mafk expression.
Mafk = fake_perturbseq(val_pert[5], filter(df_unpert), labels)

Mafk
0.9902869757174393 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[0.99631415 0.98512057 0.96202532 0.98675497 0.98866052]
[0.38785871964679913, 0.2119205298013245, 0.016777041942604858, 0.06578366445916115, 0.3079470198675497]


In [20]:
# val_pert[6] is 'Stat3': score the classifier on unperturbed cells with zero Stat3 expression.
Stat3 = fake_perturbseq(val_pert[6], filter(df_unpert), labels)

Stat3
0.989821882951654 rb
['cycling' 'effector' 'other' 'progenitor' 'terminal exhausted']
[1.         0.98773006 1.         0.98876404 0.98429319]
[0.18575063613231552, 0.40966921119592875, 0.043256997455470736, 0.11195928753180662, 0.23918575063613232]


In [31]:
# Column names follow the classifier class order printed by the cells above.
_state_columns = ['cycling', 'effector', 'other', 'progenitor', 'terminal exhausted']

# One row per gene; each row is the list returned by fake_perturbseq.
validation = pd.DataFrame(
    [Aqr, Bach2, Bhlhe40],
    index=['Aqr', 'Bach2', 'Bhlhe40'],
    columns=list(_state_columns),
)
test = pd.DataFrame(
    [Fosb, Mafk, Stat3, Ets1],
    index=['Fosb', 'Mafk', 'Stat3', 'Ets1'],
    columns=list(_state_columns),
)

In [None]:
# Write both tables to CSV in the working directory.
# index = 0 is falsy, so the gene-name row labels assigned when the frames were
# built are not written to the files — NOTE(review): confirm this is intended.
# Both frames must already exist in the running kernel; the NameError recorded
# under this cell shows it was executed without the cell that defines them.
test.to_csv('test_output.csv', index = 0)
validation.to_csv('validation_output.csv', index = 0)

NameError: name 'test' is not defined