/
algo.py
231 lines (189 loc) · 9.29 KB
/
algo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""
Top-level functions.
"""
import pandas as pd
from distributed import Client, LocalCluster
from arboreto.core import create_graph, SGBM_KWARGS, RF_KWARGS, EARLY_STOP_WINDOW_LENGTH
def grnboost2(expression_data,
              gene_names=None,
              tf_names='all',
              client_or_address='local',
              early_stop_window_length=EARLY_STOP_WINDOW_LENGTH,
              limit=None,
              seed=None,
              verbose=False):
    """
    Run network inference with the [GRNBoost2] profile.

    This is a convenience wrapper that delegates to `diy` with regressor type 'GBM' and the
    SGBM_KWARGS regressor configuration.

    :param expression_data: one of:
           * a pandas DataFrame (rows=observations, columns=genes)
           * a dense 2D numpy.ndarray
           * a sparse scipy.sparse.csc_matrix
    :param gene_names: optional list of gene names (strings). Required when a (dense or sparse) matrix is
                       passed as 'expression_data' instead of a DataFrame.
    :param tf_names: optional list of transcription factors. If None or 'all', the list of gene_names is used.
    :param client_or_address: one of:
           * None or 'local': a new Client(LocalCluster()) is created for the computation.
           * string address: a new Client(address) is created for the computation.
           * a Client instance: the given Client instance is used for the computation.
    :param early_stop_window_length: early stop window length. Defaults to EARLY_STOP_WINDOW_LENGTH
                                     imported from arboreto.core.
    :param limit: optional number (int) of top regulatory links to return. Default None.
    :param seed: optional random seed for the regressors. Default None.
    :param verbose: print progress info.
    :return: a pandas DataFrame['TF', 'target', 'importance'] representing the inferred gene regulatory links.
    """
    # Everything except the regressor profile is forwarded untouched.
    forwarded_options = dict(gene_names=gene_names,
                             tf_names=tf_names,
                             client_or_address=client_or_address,
                             early_stop_window_length=early_stop_window_length,
                             limit=limit,
                             seed=seed,
                             verbose=verbose)

    return diy(expression_data, 'GBM', SGBM_KWARGS, **forwarded_options)
def genie3(expression_data,
           gene_names=None,
           tf_names='all',
           client_or_address='local',
           limit=None,
           seed=None,
           verbose=False):
    """
    Run network inference with the [GENIE3] profile.

    This is a convenience wrapper that delegates to `diy` with regressor type 'RF' and the
    RF_KWARGS regressor configuration. No early stop window length is passed, so `diy` uses
    its own default for that parameter.

    :param expression_data: one of:
           * a pandas DataFrame (rows=observations, columns=genes)
           * a dense 2D numpy.ndarray
           * a sparse scipy.sparse.csc_matrix
    :param gene_names: optional list of gene names (strings). Required when a (dense or sparse) matrix is
                       passed as 'expression_data' instead of a DataFrame.
    :param tf_names: optional list of transcription factors. If None or 'all', the list of gene_names is used.
    :param client_or_address: one of:
           * None or 'local': a new Client(LocalCluster()) is created for the computation.
           * string address: a new Client(address) is created for the computation.
           * a Client instance: the given Client instance is used for the computation.
    :param limit: optional number (int) of top regulatory links to return. Default None.
    :param seed: optional random seed for the regressors. Default None.
    :param verbose: print progress info.
    :return: a pandas DataFrame['TF', 'target', 'importance'] representing the inferred gene regulatory links.
    """
    # Everything except the regressor profile is forwarded untouched.
    forwarded_options = dict(gene_names=gene_names,
                             tf_names=tf_names,
                             client_or_address=client_or_address,
                             limit=limit,
                             seed=seed,
                             verbose=verbose)

    return diy(expression_data, 'RF', RF_KWARGS, **forwarded_options)
def diy(expression_data,
        regressor_type,
        regressor_kwargs,
        gene_names=None,
        tf_names='all',
        client_or_address='local',
        early_stop_window_length=EARLY_STOP_WINDOW_LENGTH,
        limit=None,
        seed=None,
        verbose=False):
    """
    Run network inference with a caller-specified regressor configuration.

    Steps: obtain a dask client, normalise the inputs, build the dask graph with `create_graph`,
    compute it synchronously and return the links sorted by descending importance. The client's
    shutdown callback is always invoked afterwards, also when an error occurs.

    :param expression_data: one of:
           * a pandas DataFrame (rows=observations, columns=genes)
           * a dense 2D numpy.ndarray
           * a sparse scipy.sparse.csc_matrix
    :param regressor_type: string. One of: 'RF', 'GBM', 'ET'. Case insensitive.
    :param regressor_kwargs: a dictionary of key-value pairs that configures the regressor.
    :param gene_names: optional list of gene names (strings). Required when a (dense or sparse) matrix is
                       passed as 'expression_data' instead of a DataFrame.
    :param tf_names: optional list of transcription factors. If None or 'all', the list of gene_names is used.
    :param client_or_address: one of:
           * None or 'local': a new Client(LocalCluster()) is created for the computation.
           * string address: a new Client(address) is created for the computation.
           * a Client instance: the given Client instance is used for the computation.
    :param early_stop_window_length: early stopping window length.
    :param limit: optional number (int) of top regulatory links to return. Default None.
    :param seed: optional random seed for the regressors. Default None.
    :param verbose: print progress info.
    :return: a pandas DataFrame['TF', 'target', 'importance'] representing the inferred gene regulatory links.
    """
    def report(message):
        # Progress output is only emitted in verbose mode.
        if verbose:
            print(message)

    report('preparing dask client')
    client, shutdown_callback = _prepare_client(client_or_address)

    try:
        report('parsing input')
        prepared = _prepare_input(expression_data, gene_names, tf_names)
        expression_matrix, gene_names, tf_names = prepared

        report('creating dask graph')
        graph = create_graph(expression_matrix,
                             gene_names,
                             tf_names,
                             client=client,
                             regressor_type=regressor_type,
                             regressor_kwargs=regressor_kwargs,
                             early_stop_window_length=early_stop_window_length,
                             limit=limit,
                             seed=seed)

        # graph.npartitions is only read when it is actually going to be printed.
        if verbose:
            print('{} partitions'.format(graph.npartitions))
            print('computing dask graph')

        network = client.compute(graph, sync=True)
        return network.sort_values(by='importance', ascending=False)
    finally:
        # Runs on success and on failure, so the client is always released.
        shutdown_callback(verbose)
        report('finished')
def _prepare_client(client_or_address):
    """
    Resolve the client specification into a usable Client plus a matching shutdown callback.

    :param client_or_address: one of:
           * None
           * verbatim: 'local' (case insensitive)
           * string address
           * a Client instance
    :return: a tuple: (Client instance, shutdown callback function). The callback accepts an optional
             'verbose' flag and only closes resources that were created by this function; an externally
             supplied Client is left open.
    :raises: ValueError if no valid client input was provided.
    """
    if client_or_address is None or str(client_or_address).lower() == 'local':
        # `dashboard_address=None` disables the diagnostics dashboard. It replaces the former
        # `diagnostics_port=None` keyword, which current distributed releases no longer accept.
        local_cluster = LocalCluster(dashboard_address=None)
        client = Client(local_cluster)

        def close_client_and_local_cluster(verbose=False):
            if verbose:
                print('shutting down client and local cluster')

            client.close()
            local_cluster.close()

        return client, close_client_and_local_cluster

    # Any remaining string cannot be 'local' (handled above), so it is treated as a scheduler address.
    if isinstance(client_or_address, str):
        client = Client(client_or_address)

        def close_client(verbose=False):
            if verbose:
                print('shutting down client')

            client.close()

        return client, close_client

    if isinstance(client_or_address, Client):

        def close_dummy(verbose=False):
            # The caller owns this client, so it must not be closed here.
            if verbose:
                print('not shutting down client, client was created externally')

            return None

        return client_or_address, close_dummy

    raise ValueError("Invalid client specified {}".format(str(client_or_address)))
def _prepare_input(expression_data,
                   gene_names,
                   tf_names):
    """
    Wrangle the inputs into the correct formats.

    :param expression_data: one of:
           * a pandas DataFrame (rows=observations, columns=genes)
           * a dense 2D numpy.ndarray
           * a sparse scipy.sparse.csc_matrix
    :param gene_names: optional list of gene names (strings).
                       Required when a (dense or sparse) matrix is passed as 'expression_data' instead of
                       a DataFrame. Ignored for a DataFrame, whose column labels are used instead.
    :param tf_names: optional collection of transcription factors. If None or 'all', the list of
                     gene_names will be used.
    :return: a triple of:
             1. the expression matrix (the DataFrame's values, or 'expression_data' itself)
             2. the gene names
             3. the transcription factor names ('tf_names' as passed in, or the gene names).
    :raises: ValueError if gene_names is missing or does not match the number of matrix columns,
             if tf_names is empty, or if tf_names and gene_names have no name in common.
    """
    if isinstance(expression_data, pd.DataFrame):
        expression_matrix = expression_data.to_numpy()
        gene_names = list(expression_data.columns)
    else:
        expression_matrix = expression_data

        # Explicit checks instead of `assert`: assertions vanish under `python -O`, and a missing
        # gene_names would otherwise surface as an opaque TypeError from len(None).
        if gene_names is None:
            raise ValueError('gene_names must be specified when expression_data is not a DataFrame.')

        if expression_matrix.shape[1] != len(gene_names):
            raise ValueError('Number of columns in expression_data ({}) does not match the number of '
                             'gene_names ({}).'.format(expression_matrix.shape[1], len(gene_names)))

    # Only compare against 'all' when tf_names really is a string. Comparing a numpy array or
    # pandas Index with == is elementwise, and using that result as a condition raises.
    if tf_names is None or (isinstance(tf_names, str) and tf_names == 'all'):
        tf_names = gene_names
    else:
        # len() rather than truthiness, which is ambiguous for array-like inputs.
        if len(tf_names) == 0:
            raise ValueError('Specified tf_names is empty')

        if not set(gene_names).intersection(set(tf_names)):
            raise ValueError('Intersection of gene_names and tf_names is empty.')

    return expression_matrix, gene_names, tf_names