-
Notifications
You must be signed in to change notification settings - Fork 0
/
tm_average
executable file
·231 lines (172 loc) · 8.4 KB
/
tm_average
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python
import random
import time
import uuid
import os
from collections import Counter
from pprint import pprint
import json
import shutil
import logging
import numpy as np
from tomominer.core import read_mrc, write_mrc
from tomominer.align.runners import one_vs_all_alignment
from tomominer.average.runners import volume_average
from tomominer.filtering import gaussian_smoothing, lowpass_smoothing
def _read_triplet(record, key):
    """Return ``record[key]`` as a float array of length 3, or zeros if absent."""
    if key not in record:
        return np.zeros(3, dtype=float)
    values = record[key]
    if len(values) != 3:
        raise ValueError("%r must have exactly 3 components: %r" % (key, record))
    return np.array([float(x) for x in values])


def parse_data(conf):
    """
    Parse data on subtomogram locations, and transformations.

    :param conf: The decoded JSON configuration: an iterable of dictionaries,
        one per subtomogram.  Each dictionary is checked for four keys:

        * ``subtomogram`` (required): disk location of the subtomogram.
        * ``mask`` (required): disk location of the mask.
        * ``angle`` (optional): sequence of length 3, parsed as a ZYZ Euler
          angle.  Defaults to zeros.
        * ``loc`` (optional): sequence of length 3, the displacement shift
          for alignment.  Defaults to zeros.

    :returns: A list with one tuple ``(vol_path, mask_path, ang, loc)`` per
        record, in input order.  The paths are ``str`` and ``ang``/``loc`` are
        numpy float arrays of length 3.

    :raises ValueError: if a record lacks ``subtomogram`` or ``mask``, or if
        ``angle``/``loc`` is present but does not have exactly 3 components.
    """
    subs = []
    for record in conf:
        # Both file locations are mandatory; report the offending record.
        for required in ('subtomogram', 'mask'):
            if required not in record:
                raise ValueError(
                    "record is missing %r: %r" % (required, record))

        ang = _read_triplet(record, 'angle')
        loc = _read_triplet(record, 'loc')
        subs.append((str(record['subtomogram']), str(record['mask']), ang, loc))
    return subs
def average(vmal, tmp_dir, host, port, use_fft, L, N, smoothing):
    """
    Average Pipeline.

    Runs up to ``N`` passes.  Each pass averages the subtomograms into a
    template, optionally smooths the template, aligns every subtomogram
    against it, and writes the resulting transformations to
    ``<tmp_dir>/pass_NNN/data_config_NNN.json``.  A pass whose JSON file
    already exists is not recomputed; its file is loaded and used as the input
    of the following pass.

    :param vmal: A list of subtomograms to average.  The list is of tuples
        with four values (volume, mask, angle, disp).  The volume and mask
        entries are strings which give filesystem locations.  The angle is a
        ZYZ euler angle rotation, and the displacement is a 3-component vector
        describing the shift.  The angle and displacement are an initial
        estimate of the required transformations necessary for alignment.

    :param tmp_dir: The location where we will store intermediate results.

    :param host: The hostname/IP address of the computer running the queue
        server.  We will connect to this and use its pool of processors for
        computation.

    :param port: The port to connect to the queue host on.

    :param use_fft: Use FFT-based averaging.

    :param L: Sampling frequency for angular search (2*pi/L is sampling).

    :param N: Number of iterations of averaging to run.

    :param smoothing: If given, the type of smoothing to apply on each
        iteration: ``'gaussian'`` selects gaussian smoothing, any other truthy
        value selects lowpass smoothing.

    :returns: The list of record dictionaries of the last pass processed
        (either computed in this run or loaded from disk), or ``None`` when
        ``N`` is less than 1 and no pass ran.

    :raises Exception: if ``vmal`` lists the same subtomogram more than once.

    :todo: rewrite to shortcut reading the file if we are going to skip to the
        next iteration.  startup on 100k with 23 rounds done takes too long.
    """
    # The shape of the first subtomogram is used for the whole data set.
    vol_shape = read_mrc(vmal[0][0]).shape

    # Stays None if the loop body never runs (N < 1).
    json_data = None

    for p in range(1, N + 1):
        pass_dir = os.path.join(tmp_dir, 'pass_%03d' % (p,))
        if not os.path.isdir(pass_dir):
            # makedirs also creates tmp_dir itself when it is missing.
            os.makedirs(pass_dir)

        json_file = os.path.join(pass_dir, 'data_config_%03d.json' % (p,))

        if os.path.exists(json_file):
            logging.debug("Loading existing JSON Data file: %s and skipping to next pass.", json_file)
            with open(json_file) as f:
                json_data = json.load(f)
            # use the vmal in this round and skip to the next iteration.
            vmal = parse_data(json_data)
            continue

        # make sure that vmal does not contain duplicates:
        counter = Counter(rec[0] for rec in vmal)
        bad_list = [path for path, count in counter.most_common() if count > 1]
        if bad_list:
            raise Exception("Data contains duplicate tomograms: %s" % (bad_list,))

        vol_key, mask_key = volume_average(host, port, vmal, vol_shape, pass_dir, use_fft)

        # Give the averaging outputs stable, per-pass names.
        new_vol_key = os.path.join(pass_dir, 'template_vol_%03d_%03d.mrc' % (p, 0))
        new_mask_key = os.path.join(pass_dir, 'template_mask_%03d_%03d.mrc' % (p, 0))
        shutil.move(vol_key, new_vol_key)
        shutil.move(mask_key, new_mask_key)

        avg = (new_vol_key, new_mask_key)

        # If we smooth the average, do so, and use that as the alignment target.
        if smoothing:
            if smoothing == 'gaussian':
                smooth = gaussian_smoothing
            else:
                smooth = lowpass_smoothing

            new_smoothed_vol_key = os.path.join(
                pass_dir, 'template_smoothed_vol_%03d_%03d.mrc' % (p, 0))

            # The template now lives at new_vol_key (vol_key was moved above).
            # NOTE(review): the smoothing helpers are assumed to take
            # (input_path, output_path) and to return the path they wrote --
            # TODO confirm against tomominer.filtering.
            smoothed_vol_key = smooth(new_vol_key, new_smoothed_vol_key)
            if smoothed_vol_key is not None and smoothed_vol_key != new_smoothed_vol_key:
                shutil.move(smoothed_vol_key, new_smoothed_vol_key)

            avg = (new_smoothed_vol_key, new_mask_key)

        # Align each volume against average.
        #
        # The (volume, mask) pairs are rebuilt from the current vmal so that
        # the results line up with vmal even when vmal was reloaded from an
        # earlier pass on disk.
        vm = [(rec[0], rec[1]) for rec in vmal]
        results = one_vs_all_alignment(host, port, avg, vm, L)

        # results is an array in the same order as vmal.
        bad_cnt = 0
        json_data = []
        for rec, (score, loc, ang) in zip(vmal, results):
            # A score of exactly zero is counted as a bad transformation.
            if score == 0.0:
                bad_cnt += 1

            json_data.append(dict(subtomogram=rec[0],
                                  mask=rec[1],
                                  angle=ang,
                                  loc=loc,
                                  template=avg[0],
                                  label=0))

        if bad_cnt > 0:
            logging.warning("Bad transformations for %s of %s", bad_cnt, len(json_data))

        write_json_data(json_data, json_file)

        # The transformations just computed seed the next pass.
        vmal = parse_data(json_data)

    return json_data
def write_json_data(data, json_file):
    """
    Write alignment records to ``json_file`` as JSON indented by 2 spaces.

    The whole document is serialized in memory before the file is opened, so
    a record that cannot be serialized never leaves a truncated file behind.

    :param data: Iterable of record dictionaries.  Each must have ``'angle'``
        and ``'loc'`` entries that can be iterated; both are replaced **in
        place** by tuples, so the caller's dictionaries are modified.

    :param json_file: Path of the file to (over)write.

    :raises TypeError, ValueError: whatever ``json.dumps`` raised for the
        first record that could not be serialized.  Every failing record is
        logged first and ``json_file`` is left untouched.

    :todo: split out into io.py with a read_json_data function.  Rename to
        read/write_vmal and add validation of json structure.
    """
    json_data = []
    for d in data:
        # numpy arrays are not JSON serializable.  Convert to a tuple of numbers.
        d['angle'] = tuple(d['angle'])
        d['loc'] = tuple(d['loc'])
        json_data.append(d)

    # Check records one by one so that every bad record gets reported.
    first_error = None
    for d in json_data:
        try:
            json.dumps(d)
        except (TypeError, ValueError) as e:
            logging.error("Failed to dump record: %r (%s)", d, e)
            if first_error is None:
                first_error = e
    if first_error is not None:
        raise first_error

    text = json.dumps(json_data, indent=2)

    # write new data to disk
    with open(json_file, 'w') as f:
        f.write(text)
def build_arg_parser():
    """
    Build the command-line parser of the averaging script.

    :returns: An ``argparse.ArgumentParser`` with the options ``--host``,
        ``-p/--port``, ``-d/--data``, ``-t/--tmp_dir``, ``-L``,
        ``-n/--iterations``, ``-v/--verbose`` (repeatable, stored as
        ``verbose_count``), ``--smoothing`` and ``--fft``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="TomoMiner Averaging")

    parser.add_argument('--host', default="127.0.0.1", type=str, help="Address of TomoMiner server (default 127.0.0.1)")
    parser.add_argument('-p', '--port', default=5011, type=int, help="Port of TomoMiner server (default 5011)")
    parser.add_argument('-d', '--data', default="data.json", type=str, help="Data VMAL file")
    parser.add_argument('-t', '--tmp_dir', default="/tmp/", type=str, help="Temporary file location.")
    parser.add_argument('-L', default=36, type=int, help="Angular resolution parameter (2*pi/L is sampling).")
    parser.add_argument('-n', '--iterations', default=20, type=int, help="Number of averaging rounds to run.")
    parser.add_argument("-v", "--verbose", dest="verbose_count", action="count", default=0, help="set verbosity")
    # The keyword is ``choices``; ``choice`` makes add_argument raise TypeError.
    parser.add_argument("--smoothing", choices=['gaussian', 'lowpass'], help="Type of smoothing to apply, if any.", default=None)
    parser.add_argument("--fft", action='store_true', help="Use FFT-space averaging")

    return parser


if __name__ == '__main__':
    import logging

    args = build_arg_parser().parse_args()

    # No -v gives level 30 (WARNING); each -v lowers the threshold by 10,
    # bottoming out at 0 after three or more.
    logging.basicConfig(level=max(3 - args.verbose_count, 0) * 10,
                        format='%(asctime)-15s %(name)-10s %(levelname)-8s %(message)s')

    with open(args.data) as f:
        conf = json.load(f)

    data = parse_data(conf)

    average(data, args.tmp_dir, args.host, args.port, args.fft, args.L, args.iterations, args.smoothing)