/
enhance_censoring.py
152 lines (125 loc) · 4.39 KB
/
enhance_censoring.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Update a file denoting censored volumes (e.g., high-motion volumes) to censor
volumes before, after, and between.
"""
import argparse
from os.path import splitext, join, basename, dirname
import numpy as np
def _to_vec(arr):
"""
Convert bool-like 2D array to 1D vector.
Assumes TRxCensored-Vol format.
"""
if len(arr.shape) > 2:
raise Exception('Not a 2D array.')
elif len(arr.shape) == 1:
return arr
else:
vec = np.sum(arr, axis=1)
if np.any(vec > 1):
raise Exception('Boolean censoring array does not appear to be properly formatted.')
return vec
def _to_arr(vec):
"""
Convert bool-like 1D vector to 2D array.
"""
# I think this is the ugliest way of doing this, but...
if len(vec.shape) != 1:
raise Exception('Not a 1D vector.')
n_trs = vec.size
n_cens = np.sum(vec)
arr = np.zeros((n_trs, int(n_cens)), int)
cols = range(int(n_cens))
rows = np.where(vec)[0]
arr[rows, cols] = 1
return arr
def main(in_file, out_file, n_contig=2, n_before=1, n_after=2):
    """
    Censor non-contiguous TRs based on outlier file.

    Reads a censoring file, expands the censored set to include
    ``n_before`` volumes before and ``n_after`` volumes after each
    censored volume, additionally censors any run of up to ``n_contig``
    uncensored volumes sandwiched between censored ones, and writes the
    result in the same format as the input.

    Parameters
    ----------
    in_file : str
        Path to censor file. ``.1D`` is treated as an AFNI censor column
        (1 = keep, 0 = censor); ``.txt`` as an FSL TRxCensored-Vol
        one-hot matrix.
    out_file : str
        Path for the enhanced censor file (same format as input).
    n_contig : int, optional
        Maximum number of uncensored volumes between two censored ones
        that will also be censored. Default: 2.
    n_before : int, optional
        Number of volumes to censor before each outlier. Default: 1.
    n_after : int, optional
        Number of volumes to censor after each outlier. Default: 2.

    Raises
    ------
    Exception
        If the file extension is unrecognized or the contents are not in
        the expected shape.
    """
    censor_data = np.loadtxt(in_file)
    _, suff = splitext(in_file)
    if suff == '.1D':
        package = 'AFNI'
        # A single-TR file loads as a 0-d array; promote it so the shape
        # check below behaves sensibly.
        censor_data = np.atleast_1d(censor_data)
        if censor_data.ndim != 1:
            raise Exception('Not a 1D vector. Shape: {0}'.format(censor_data.shape))
        # AFNI convention is 1 = keep / 0 = censor; flip to 1 = censored.
        censor_vec = 1 - censor_data.astype(int)
    elif suff == '.txt':
        package = 'FSL'
        censor_vec = _to_vec(censor_data)
    else:
        raise Exception('Unrecognized file type {0}'.format(suff))

    n_trs = len(censor_vec)
    all_vols = np.arange(n_trs)
    cens_vols = np.where(censor_vec)[0]

    # Flag volumes before each outlier.
    temp = np.copy(cens_vols)
    for trs_before in range(1, n_before + 1):
        temp = np.hstack((temp, cens_vols - trs_before))
    # Drop flagged indices that fall before the start of the run.
    cens_vols = np.intersect1d(all_vols, np.unique(temp))

    # Flag volumes after each outlier.
    temp = np.copy(cens_vols)
    for trs_after in range(1, n_after + 1):
        temp = np.hstack((temp, cens_vols + trs_after))
    # Drop flagged indices that fall past the end of the run.
    cens_vols = np.intersect1d(all_vols, np.unique(temp))

    # Flag orphan volumes: short runs of uncensored volumes between two
    # censored ones. A gap of g uncensored volumes between indices i and
    # j means j - i == g + 1, so fill whenever diff <= n_contig + 1.
    # (The original condition `diff < n_contig` was off by one and made
    # the default --between=2 a no-op, since diff == 1 has no gap.)
    temp = np.copy(cens_vols)
    gap_idx = np.where(np.diff(cens_vols) <= n_contig + 1)[0]
    for idx in gap_idx:
        temp = np.hstack((temp, np.arange(cens_vols[idx], cens_vols[idx + 1])))
    cens_vols = np.unique(temp)

    # Build the enhanced censor vector and write it in the input format.
    out_vec = np.zeros(n_trs, int)
    out_vec[cens_vols] = 1
    if package == 'AFNI':
        out_data = 1 - out_vec  # back to AFNI's 1 = keep convention
    else:  # package == 'FSL' (guaranteed by the suffix check above)
        out_data = _to_arr(out_vec)
    np.savetxt(out_file, out_data, fmt='%i', delimiter='\t')
def _get_parser():
"""
Argument parser for enhance_censoring
"""
parser = argparse.ArgumentParser(description='Enhance censoring.')
# Required arguments
parser.add_argument('in_file',
type=str,
help='1D or txt file containing censoring index')
parser.add_argument('out_file',
type=str,
help='Output file')
# Optional arguments
parser.add_argument('--between',
dest='n_contig',
action='store',
type=int,
help=('Number of volumes between outliers to censor'),
default=2)
parser.add_argument('--pre',
dest='n_before',
action='store',
type=int,
help=('Number of volumes before outliers to censor'),
default=1)
parser.add_argument('--post',
dest='n_after',
action='store',
type=int,
help=('Number of volumes after outliers to censor'),
default=2)
return parser
def _main(argv=None):
    """
    Command-line entry point: parse arguments and run ``main``.

    Parameters
    ----------
    argv : list of str or None, optional
        Argument list to parse; ``None`` uses ``sys.argv[1:]``.
    """
    options = _get_parser().parse_args(argv)
    main(**vars(options))
# Allow the module to be executed directly as a script.
if __name__ == '__main__':
    _main()