-
Notifications
You must be signed in to change notification settings - Fork 27
/
utils.py
220 lines (153 loc) · 5.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
from __future__ import division
import numpy as np
from emcee.utils import MPIPool
try:
from mpi4py import MPI
MPI=MPI
default_op = MPI.SUM
except ImportError:
MPI=None
default_op=None
print("Warning! mpi4py installation not found or broken!")
#####################################################################################
#######Fallback for the missing rfftfreq implementation in numpy<1.8#################
#####################################################################################
def rfftfreq(n, d=1.0):
    """
    Return the sample frequencies for a real FFT of window length n

    The result is the array ``[0, 1, ..., n//2] / (n*d)``.

    :param n: window length
    :type n: int (builtin or numpy integer)

    :param d: sample spacing
    :type d: float

    :returns: array of length ``n//2 + 1`` with the sample frequencies
    :rtype: numpy array

    :raises: ValueError if n is not an integer

    """

    # numpy integer scalars are not instances of the builtin int on Python 3,
    # so both kinds have to be accepted explicitly
    if not isinstance(n, (int, np.integer)):
        raise ValueError("n should be an integer")

    val = 1.0/(n*d)
    N = n//2 + 1
    results = np.arange(0, N, dtype=int)

    return results * val
###########################################################################
###########Hack to make scipy interpolate objects pickleable###############
###########################################################################
class _interpolate_wrapper(object):

    """
    Deferred call: stores a callable together with its positional and keyword
    arguments, and invokes it when the instance itself is called

    """

    def __init__(self, f, args, kwargs):
        self.f, self.args, self.kwargs = f, args, kwargs

    def __call__(self):
        try:
            outcome = self.f(*self.args, **self.kwargs)
        except BaseException:
            # Report the failure on stdout, then let the exception propagate
            from traceback import print_exc
            print("lenstools: Exception while building the interpolators")
            print(" exception:")
            print_exc()
            raise
        return outcome
#################################################################################################
#############################Principal Component Analysis handler################################
#################################################################################################
def pca_transform(data, pca, n_components):
    """
    Project data with the given PCA object and keep only the leading components

    :param data: data to project, passed unchanged to ``pca.transform``

    :param pca: object exposing a ``components_`` array and a ``transform`` method

    :param n_components: number of leading components to keep; cannot exceed ``pca.components_.shape[0]``
    :type n_components: int

    :returns: the output of ``pca.transform(data)`` truncated to its first n_components entries along the last axis

    """

    available = pca.components_.shape[0]
    assert n_components <= available

    projected = pca.transform(data)

    # Transpose, slice the leading rows, transpose back
    leading = projected.T[:n_components]
    return leading.T
class pcaHandler(object):

    """
    Handles principal component analysis

    The data is whitened (zero mean, unit variance along axis 0) and divided by sqrt(N-1),
    N being the number of rows, before a singular value decomposition is performed on it.
    After fit() has been called the instance exposes:

    - components_: the right singular vectors, one per row
    - explained_variance_: the squared singular values

    """

    @staticmethod
    def _float_copy(array):
        """
        Return a copy of array that is safe to modify in place with floating point operands

        Integer and boolean arrays are promoted to float, because numpy refuses to cast
        the float results of in-place operations back to those types; any other input
        is copied unchanged

        """

        duplicate = array.copy()
        dtype = getattr(duplicate, "dtype", None)

        if dtype is not None and (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_)):
            duplicate = duplicate.astype(float)

        return duplicate

    def fit(self, data):
        """
        Compute the principal components of the data

        :param data: data to decompose; the statistics are computed along axis 0
        :type data: numpy array

        """

        #Scale the data to zero mean and unit variance
        self._pca_mean = data.mean(0)
        self._pca_std = data.std(0)

        self._data_scaled = self._float_copy(data)
        self._data_scaled -= self._pca_mean[None]
        self._data_scaled /= self._pca_std[None]
        self._data_scaled /= np.sqrt(self._data_scaled.shape[0] - 1)

        #Perform singular value decomposition
        left, eigenvalues, right = np.linalg.svd(self._data_scaled, full_matrices=False)

        #Assign eigenvalues and eigenvectors as attributes
        self.components_ = right
        self.explained_variance_ = eigenvalues**2

    @property
    def eigenvalues(self):
        """
        Squared singular values computed by fit()

        """
        return self.explained_variance_

    @property
    def eigenvectors(self):
        """
        The rows of components_ multiplied by sqrt(N-1) and by the standard deviation, with the mean added

        """
        return self.components_*np.sqrt(self._data_scaled.shape[0] - 1)*self._pca_std[None] + self._pca_mean[None]

    def transform(self, X):
        """
        Project X on the principal components computed by fit()

        :param X: a single vector (1D) or one vector per row (2D)
        :type X: numpy array

        :returns: the projections on the components; 1D if X is 1D, otherwise one row per row of X
        :rtype: numpy array

        """

        #Cast X to the right dimensions
        if len(X.shape) == 1:
            X_copy = self._float_copy(X)[None]
        else:
            X_copy = self._float_copy(X)

        #Subtract mean and scale by variance
        X_copy -= self._pca_mean[None]
        X_copy /= (self._pca_std[None]*np.sqrt(self._data_scaled.shape[0] - 1))

        #Compute the projection via dot product
        components = X_copy.dot(self.components_.transpose())

        if len(X.shape) == 1:
            return components[0]
        else:
            return components

    def inverse_transform(self, X, n_components=None):
        """
        Map projections on the principal components back to the original space

        :param X: projections, a single vector (1D) or one vector per row (2D)
        :type X: numpy array

        :param n_components: if not None, only the first n_components components are used in the reconstruction
        :type n_components: int

        :returns: the reconstructed vectors; a result made of a single row is returned as a 1D array, even if X was 2D
        :rtype: numpy array

        """

        #Cast X to the right dimensions
        if len(X.shape) == 1:
            X_copy = X.copy()[None]
        else:
            X_copy = X.copy()

        #Use the PCA basis vectors to project back to the original space
        if n_components is not None:
            basis_vectors = self.components_[:n_components]
            X_copy = X_copy[:, :n_components]
        else:
            basis_vectors = self.components_

        #Original space
        original_components = X_copy.dot(basis_vectors)

        #De-whitening
        original_components *= (self._pca_std[None]*np.sqrt(self._data_scaled.shape[0] - 1))
        original_components += self._pca_mean[None]

        if original_components.shape[0] == 1:
            return original_components[0]
        else:
            return original_components

    def select_components(self, X, n_components):
        """
        Project X on the principal components and reconstruct it using only the first n_components of them

        :param X: a single vector (1D) or one vector per row (2D)
        :type X: numpy array

        :param n_components: number of components to keep in the reconstruction
        :type n_components: int

        :returns: the output of inverse_transform applied to the projections of X
        :rtype: numpy array

        """

        all_components = self.transform(X)
        return self.inverse_transform(all_components, n_components=n_components)
#################################################################################################
###################MPIWhirlPool: should handle one sided communications too######################
#################################################################################################
class MPIWhirlPool(MPIPool):

    """
    MPIPool subclass that adds one sided communication utilities, built on RMA windows

    """

    def openWindow(self, memory):

        """
        Open a RMA window on a memory buffer, over the communicator of the pool

        :param memory: memory buffer on which to open the window
        :type memory: numpy nd array

        """

        #Only numpy arrays are accepted as window memory
        assert isinstance(memory, np.ndarray)
        self.memory = memory

        #Create the window and fence it right away
        window = MPI.Win.Create(memory=memory, comm=self.comm)
        self.win = window
        window.Fence()

    def get(self, process):

        """
        Read the content of the RMA window open on a particular process

        :param process: target passed to the Get call on the window

        :returns: on the master, a new array with the shape and dtype of the local window memory filled by the Get call; None on every other process

        """

        shape, dtype = self.memory.shape, self.memory.dtype
        read_buffer = np.zeros(shape, dtype=dtype)

        #Every caller fences before and after; only the master issues the Get
        self.win.Fence()
        on_master = self.is_master()

        if on_master:
            self.win.Get(read_buffer, process)

        self.win.Fence()

        return read_buffer if on_master else None

    def accumulate(self, op=default_op):

        """
        Accumulate the window memory of the processes with rank 1 to size onto rank 0, one rank at a time

        :param op: operation passed to the Accumulate call (the module level default is MPI.SUM when mpi4py is available)

        """

        last_rank = self.size

        for turn in range(1, last_rank + 1):

            #Each turn is enclosed between two fences; only the rank whose turn it is accumulates
            self.win.Fence()

            if turn == self.rank:
                self.win.Accumulate(self.memory, 0, op=op)

            self.win.Fence()

    def closeWindow(self):

        """
        Fence and free a previously opened RMA window

        """

        window = self.win
        window.Fence()
        window.Free()