In [94]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

### Load and summarize some points in the data 

In [75]:
df = pd.read_csv('./data/train_set.csv', index_col=0)

genes_to_predict = list(df.index)
pairs_to_predict = list(pd.read_csv('./data/test_set.csv', header=None)[0])

print('Number of genes to predict = ' + str(len(genes_to_predict)))
print('Number of pairs to predict = ' + str(len(pairs_to_predict)))

Number of genes to predict = 1000
Number of pairs to predict = 50


In [87]:
s_per = [x for x in df.columns if '+ctrl' in x]
s_per = [x[:5] for x in s_per]
s_per = list(set(s_per))
s_per.sort()
print('Number of unique single perturbations in data set = ' + str(len(s_per)))

d_per = [x for x in df.columns if (('+' in x)&('ctrl' not in x))]
d_per = [x[:11] for x in d_per]
d_per = list(set(d_per))
d_per.sort()
print('Number of unique double perturbations in data set = ' + str(len(d_per)))

Number of unique single perturbations in data set = 101
Number of unique double perturbations in data set = 75


### Make prediction

In [107]:
preds = pd.DataFrame(columns=['gene', 'perturbation', 'expression'], index=range(50_000))

j = 0
for i in tqdm(range(len(genes_to_predict))):
# for i in tqdm(range(1)):
    gene = genes_to_predict[i]
    for pair in pairs_to_predict:
    # for k in range(1):
    #     pair = pairs_to_predict[0]
        g1 = pair.split('+')[0]
        g2 = pair.split('+')[1]
        # add control columns
        cols_to_average = [x for x in df.columns if (('ctrl' in x)&(len(x)<=7))]

        # add columns if either gene in the perturbation was part of a previous double perturbation
        for double in d_per:
            if ((g1 in double) | (g2 in double)):
                cols_to_add = [x for x in df.columns if double in x]
                cols_to_average.extend(cols_to_add)
        
        # add columns if either gene in the perturbation was previously sinlgy perturbed
        for single in s_per:
            if ((g1 in single) | (g2 in single)):
                cols_to_add = [x for x in df.columns if (((single+'+') in x)|(('+'+single) in x))]
                cols_to_average.extend(cols_to_add)
        temp = df.loc[gene]
        temp = temp[cols_to_average]
        expression = np.mean(temp)
        preds.iloc[j] = [gene, pair, expression]
        j += 1
preds        

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:18<00:00,  1.79it/s]


Unnamed: 0,gene,perturbation,expression
0,g0001,g0037+g0083,0.36619
1,g0001,g0083+g0605,0.38051
2,g0001,g0095+g0257,0.234398
3,g0001,g0095+g0520,0.21667
4,g0001,g0109+g0317,0.31663
...,...,...,...
49995,g1000,g0924+g0852,0.935889
49996,g1000,g0927+g0186,1.1313
49997,g1000,g0957+g0186,1.109239
49998,g1000,g0957+g0261,0.971783


In [118]:
preds = preds.sort_values(['perturbation','gene'])
preds.to_csv('./prediction/prediction.csv', index=False)

### Test that I have the right order

In [127]:
template = pd.read_csv('./prediction/prediction-template.csv')
template = template[['gene', 'perturbation']]
submission = preds[['gene', 'perturbation']].reset_index(drop=True)

In [130]:
(submission == template).all()

gene            True
perturbation    True
dtype: bool