In [1]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.data as data_utils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn
from tqdm import tqdm
import timeit

Input data:  
- CCLE_expression.csv  
Gene expression TPM (transcript per million) values of the protein coding genes for DepMap cell lines  
log2(TPM+1) goes from 0.00 to 17.78


- CRISPR_gene_effect.csv
Gene effect of the protein coding genes for DepMap cell lines
Gene Effect scores derived from CRISPR knockout screens published by Broad’s Achilles and Sanger’s SCORE projects.  
Negative scores imply cell growth inhibition and/or death following gene knockout. Scores are normalized such that nonessential genes have a median score of 0 and independently identified common essentials have a median score of -1.  
SCORE goes from -3.89 to 2.73.



Genes:17386  
Cell Lines:1086  
Primary Diseases:31  
Lineages:28  

## Extract data 

In [40]:
# Extract gene expression from CCLE_expression.csv
df = pd.read_csv('data/CCLE_expression.csv')
gene_expression = df.rename(columns={"Unnamed: 0":"DepMap"}).set_index("DepMap").dropna()

# Extract gene effect from CRISPR_gene_effect.csv
df = pd.read_csv('data/CRISPR_gene_effect.csv')
gene_effect = df.rename(columns={"DepMap_ID":"DepMap"}).set_index("DepMap").dropna()

In [41]:
# Extract protein links 9606.protein.links.v11.5.txt
protein_links = pd.read_csv('data/9606.protein.links.v11.5.txt', sep=' ')

# Extract protein info from 9606.protein.info.v11.5.txt
df = pd.read_csv('data/9606.protein.info.v11.5.txt', sep='	')
protein_info = df.rename(columns={"#string_protein_id":"protein",
                   "preferred_name":"gene_ID"})[["protein","gene_ID"]]

In [33]:
gene_expression.head()

Unnamed: 0,DepMap,TSPAN6 (7105),TNMD (64102),DPM1 (8813),SCYL3 (57147),C1orf112 (55732),FGR (2268),CFH (3075),FUCA2 (2519),GCLC (2729),...,H3C2 (8358),H3C3 (8352),AC098582.1 (8916),DUS4L-BCAP29 (115253422),C8orf44-SGK3 (100533105),ELOA3B (728929),NPBWR1 (2831),ELOA3D (100506888),ELOA3 (162699),CDR1 (1038)
0,ACH-001113,4.331992,0.0,7.364397,2.792855,4.470537,0.028569,1.226509,3.042644,6.499686,...,2.689299,0.189034,0.201634,2.130931,0.555816,0.0,0.275007,0.0,0.0,0.0
1,ACH-001289,4.566815,0.584963,7.106537,2.543496,3.50462,0.0,0.189034,3.813525,4.221104,...,1.286881,1.049631,0.321928,1.464668,0.632268,0.0,0.014355,0.0,0.0,0.0
2,ACH-001339,3.15056,0.0,7.379032,2.333424,4.227279,0.056584,1.31034,6.687061,3.682573,...,0.594549,1.097611,0.831877,2.946731,0.475085,0.0,0.084064,0.0,0.0,0.042644
3,ACH-001538,5.08534,0.0,7.154109,2.545968,3.084064,0.0,5.868143,6.165309,4.489928,...,0.214125,0.632268,0.298658,1.641546,0.443607,0.0,0.028569,0.0,0.0,0.0
4,ACH-000242,6.729145,0.0,6.537607,2.456806,3.867896,0.799087,7.208381,5.569856,7.127014,...,1.117695,2.358959,0.084064,1.910733,0.0,0.0,0.464668,0.0,0.0,0.0


In [34]:
gene_effect.head()

Unnamed: 0_level_0,A1BG (1),A1CF (29974),A2M (2),A2ML1 (144568),A3GALT2 (127550),A4GALT (53947),A4GNT (51146),AAAS (8086),AACS (65985),AADAC (13),...,ZWILCH (55055),ZWINT (11130),ZXDA (7789),ZXDB (158586),ZXDC (79364),ZYG11A (440590),ZYG11B (79699),ZYX (7791),ZZEF1 (23140),ZZZ3 (26009)
DepMap,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ACH-000001,-0.134808,0.059764,-0.008665,-0.003572,-0.106211,-0.008257,0.018711,-0.291985,0.010921,0.064932,...,-0.037619,-0.116524,-0.029331,0.10594,0.147605,-0.119822,0.063387,0.160857,0.058648,-0.316792
ACH-000004,0.081853,-0.056401,-0.106738,-0.014499,0.078209,-0.137562,0.168657,-0.19856,0.133372,0.1513,...,-0.030901,-0.26222,0.136406,0.031327,0.093763,-0.079692,-0.173709,0.153632,0.175627,-0.040869
ACH-000005,-0.094196,-0.014598,0.100426,0.169103,0.032363,-0.14805,0.168931,-0.244777,-0.086871,-0.036037,...,0.039434,-0.336925,-0.095528,-0.035541,-0.035612,-0.040183,-0.165464,0.077343,0.019387,-0.085687
ACH-000007,-0.011544,-0.123189,0.080692,0.061046,-0.013454,-0.016922,-0.029474,-0.206516,-0.063998,0.139288,...,-0.229303,-0.463191,0.061641,0.190301,0.119388,-0.036695,-0.182449,-0.146936,-0.189451,-0.281167
ACH-000009,-0.050782,-0.037466,0.068885,0.090375,0.012634,-0.079339,-0.017808,-0.183192,0.006227,-0.0017,...,-0.157219,-0.318765,0.015761,0.196949,-0.045874,-0.186805,-0.275629,-0.001227,-0.04914,-0.240582


In [14]:
protein_links.head()

Unnamed: 0,protein1,protein2,combined_score
0,9606.ENSP00000000233,9606.ENSP00000379496,155
1,9606.ENSP00000000233,9606.ENSP00000314067,197
2,9606.ENSP00000000233,9606.ENSP00000263116,222
3,9606.ENSP00000000233,9606.ENSP00000361263,181
4,9606.ENSP00000000233,9606.ENSP00000409666,270


In [24]:
protein_info.head()

Unnamed: 0,protein,gene_ID
0,9606.ENSP00000000233,ARF5
1,9606.ENSP00000000412,M6PR
2,9606.ENSP00000001008,FKBP4
3,9606.ENSP00000001146,CYP26B1
4,9606.ENSP00000002125,NDUFAF7


## Clean datasets

In [42]:
gene_expression = gene_expression.sort_values(by = 'DepMap').reindex(sorted(gene_expression.columns), axis = 1)
gene_effect = gene_effect.sort_values(by = 'DepMap').reindex(sorted(gene_effect.columns), axis = 1)

In [43]:
def intersection(lst1, lst2):
    return list(set(lst1).intersection(lst2))

In [44]:
# Keep genes in common
gene_expression_genes = gene_expression.columns
gene_effect_genes = gene_effect.columns

intersect_expression_effect = sorted(intersection(gene_expression_genes, gene_effect_genes))
gene_expression = gene_expression[intersect_expression_effect]
gene_effect = gene_effect[intersect_expression_effect]

In [53]:
def get_gene_ID(name):
    split = name.split(' ')
    assert len(split) == 2
    return split[0]

In [67]:
#dict_rename = dict((name,get_gene_ID(name)) for name in intersect_expression_effect)
gene_expression = gene_expression.rename(columns=dict_rename)
gene_effect = gene_effect.rename(columns=dict_rename)

In [68]:
gene_ID_expr = [get_gene_ID(name) for name in intersect_expression_effect]

In [69]:
gene_ID_prot = list(protein_info['gene_ID'])

In [70]:
gene_ID_list = intersection(gene_ID_expr, gene_ID_prot)

In [71]:
gene_expression = gene_expression[gene_ID_list]
gene_effect = gene_effect[gene_ID_list]

In [74]:
gene_effect.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1081 entries, ACH-000001 to ACH-002926
Columns: 16487 entries, ZBTB43 to ULK4
dtypes: float64(16487)
memory usage: 136.0+ MB
