-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_reals.py
128 lines (107 loc) · 5.88 KB
/
run_reals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
'''
Run REALS on the synthetic oscillation samples.

Generate the samples with generate_oscillation.py first, then run:
    python -m scripts.run_reals
'''
import glob
import os
import numpy as np
import torch
import skimage.io as skio
import scipy.io as scio
import argparse
from reals.REALS import REALS
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` cannot be used for this: argparse passes the raw string, and
    ``bool('False')`` is ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):  # the default value is passed through as-is
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options. They are parsed at import time into the module-level
# ``args`` that the ``__main__`` guard hands to ``main``.
parser = argparse.ArgumentParser()
parser.add_argument('--lr_wwt', type=float, default=1e-4, help='learning rate of W.')
parser.add_argument('--lr_tau', type=float, default=1e-2, help='learning rate of tau.')
parser.add_argument('--epoch', type=int, default=2000, help='number of iterations.')
# NOTE: learn_tau_reals overwrites args.batch_size with the number of frames,
# so this flag currently has no effect in this script.
parser.add_argument('--batch_size', type=int, default=None, help='batch size, for mini-batch training.')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--rank', type=int, default=1, help='rank of W')
# Only these two transformation models are handled by learn_tau_reals.
parser.add_argument('--tau', type=str, default='euclidean', choices=['affine', 'euclidean'],
                    help='transformation to use. affine, euclidean are avail.')
parser.add_argument('--verbose', type=_str2bool, default=True, help='Print stuffs. (Time/Epoch/loss)')
parser.add_argument("--dir", default='./data/2d_zebrafish_brain_data', type=str, help="directory of the tif")
parser.add_argument("--type", default='normal', type=str, help="type of folder. normal/br/no.")
args = parser.parse_args()
def _theta_to_affine(model, tau_kind, num_frames):
    """Build the per-frame 2x3 affine matrices from a trained REALS model.

    Args:
        model: model object returned by ``REALS``. Reads ``model.theta`` for
            ``'affine'``, and ``model.theta_ro`` / ``model.theta_tr`` for
            ``'euclidean'``.
        tau_kind: ``'affine'`` or ``'euclidean'``.
        num_frames: number of frames T (first dimension of the result).

    Returns:
        A (T, 2, 3) CPU tensor for ``'euclidean'``. For ``'affine'``, ``model.theta``
        itself (NOTE(review): assumed to be (T, 2, 3) — confirm against REALS).
    """
    if tau_kind == 'affine':
        return model.theta
    # Euclidean: rotation angle theta_ro[:, 0] plus translation theta_tr[:, 0:2].
    cos_t = torch.cos(model.theta_ro[:, 0])
    sin_t = torch.sin(model.theta_ro[:, 0])
    tau = torch.zeros((num_frames, 2, 3))
    tau[:, 0, 0] = cos_t
    tau[:, 1, 1] = cos_t
    tau[:, 0, 1] = -sin_t
    tau[:, 1, 0] = sin_t
    tau[:, 0, 2] = model.theta_tr[:, 0]
    tau[:, 1, 2] = model.theta_tr[:, 1]
    return tau


def _to_frames_first(volume, swap_spatial):
    """Convert an (H, W, T) tensor to a frames-first NumPy array.

    Returns (T, H, W), or (T, W, H) when ``swap_spatial`` is true.
    """
    order = (2, 1, 0) if swap_spatial else (2, 0, 1)
    return volume.detach().cpu().permute(*order).numpy()


def learn_tau_reals(Y, args, tau_3x3):
    """Run REALS on an image stack and compose its transforms with ``tau_3x3``.

    Args:
        Y: 3-D float tensor. Assumed to be frames-first (T, H, W); the
            ``permute(1, 2, 0)`` below moves the first axis last. Not modified.
        args: options namespace passed to ``REALS``. ``args.tau`` selects the
            transformation model, and ``args.batch_size`` is overwritten with T.
        tau_3x3: tensor left-multiplied (``torch.matmul``) with the learned
            per-frame 3x3 matrices. Presumably (T, 3, 3) ground-truth
            transforms — TODO confirm against generate_oscillation.py.

    Returns:
        ``(Y_np, L_np, L_stat_np, S_np, tau_bear_comp_3x3, elapsed_time,
        tau_bear_3x3)``. The four arrays are frames-first; their two spatial
        axes are swapped when H < W. ``elapsed_time`` is whatever ``REALS``
        reports.

    Raises:
        ValueError: if ``args.tau`` is neither ``'affine'`` nor ``'euclidean'``.
    """
    # Validate before the (expensive) REALS run instead of silently exiting after it.
    if args.tau not in ('affine', 'euclidean'):
        raise ValueError(f"unsupported tau {args.tau!r}; expected 'affine' or 'euclidean'")

    Y = Y.permute(1, 2, 0)  # (T, H, W) -> (H, W, T)
    y_max = Y.max()
    if y_max != 0:  # an all-zero stack would otherwise become all-NaN
        Y = Y / y_max  # out-of-place: do not mutate the caller's tensor
    data_shape = list(Y.size())  # [H, W, T]
    num_frames = data_shape[-1]
    args.batch_size = num_frames  # REALS processes all frames in a single batch
    Y_res = Y.reshape(data_shape[0] * data_shape[1], num_frames).permute(1, 0)  # T x (H*W)
    L_res, S_res, L_stat_res, _, model, elapsed_time = REALS(Y_res, args, data_shape)

    with torch.no_grad():
        tau_bear_3x3 = torch.zeros((num_frames, 3, 3))
        tau_bear_3x3[:, 0:2, :] = _theta_to_affine(model, args.tau, num_frames)
        tau_bear_3x3[:, 2, 2] = 1  # homogeneous row [0, 0, 1]
        # torch.matmul does not promote dtypes; a float64 .mat would otherwise fail here.
        tau_bear_comp_3x3 = torch.matmul(tau_3x3.to(tau_bear_3x3.dtype), tau_bear_3x3)

    # REALS outputs are T x (H*W); reshape them back to (H, W, T).
    L = L_res.permute(1, 0).reshape(data_shape)
    S = S_res.permute(1, 0).reshape(data_shape)
    L_stat = L_stat_res.permute(1, 0).reshape(data_shape)

    # Frames-first output; when H < W the spatial axes are swapped so the
    # shorter side is last. NOTE(review): presumably so the side-by-side
    # montage built by the caller runs along the shorter side — confirm intent.
    swap = data_shape[0] < data_shape[1]
    Y_np = _to_frames_first(Y, swap)
    L_np = _to_frames_first(L, swap)
    S_np = _to_frames_first(S, swap)
    L_stat_np = _to_frames_first(L_stat, swap)
    return Y_np, L_np, L_stat_np, S_np, tau_bear_comp_3x3, elapsed_time, tau_bear_3x3
def calculate_tau_std(tau):
    """Per-frame L1 deviation of the transforms from their mean over frames.

    Args:
        tau: NumPy array whose first axis indexes frames, e.g. (T, 2, 3). Each
            frame is flattened before comparison.

    Returns:
        1-D array of length T. Entry t is the L1 norm of (flattened frame t
        minus the mean flattened frame). Its mean is reported by ``main`` as
        the "alignment inconsistency".
    """
    tau = np.asarray(tau)
    flat = tau.reshape((tau.shape[0], -1))
    deviation = flat - flat.mean(axis=0, keepdims=True)  # broadcasting replaces np.tile
    return np.linalg.norm(deviation, ord=1, axis=1)


def main(args, num_samples=5):
    """Run REALS on every ``*tr_*_ro_*`` sample folder and save the results.

    For each folder under ``{args.dir}/{args.type}`` and each sample index
    ``i`` in ``range(num_samples)``:
    - read ``Y_{root}_{i}.tif`` and ``tau_{root}_{i}.mat``;
    - run ``learn_tau_reals``;
    - write the Y/L/S/L_stat stacks, a side-by-side montage, the composed
      2x3 transforms (``tau_{i}.mat``) and the run time (``time_{i}.mat``)
      under ``./results/{args.type}/reals_{root}``;
    - print the alignment inconsistency.

    Args:
        args: parsed command-line namespace (``dir``, ``type``, plus the REALS options).
        num_samples: number of samples per folder; defaults to the original 5.
    """
    os.makedirs(f"./results/{args.type}", exist_ok=True)
    # Sorted so runs are processed in a reproducible order.
    folder_list = sorted(glob.glob(f'{args.dir}/{args.type}/*tr_*_ro_*'))
    if not folder_list:
        print(f'no sample folders matching *tr_*_ro_* found in {args.dir}/{args.type}')
        return

    for folder in folder_list:
        root = os.path.basename(os.path.normpath(folder))  # e.g. tr_0.0_ro_0.0
        result_folder = f'{args.type}/reals_{root}'
        print(f'current file: {root}, result folder: {result_folder}')
        out_dir = f"./results/{result_folder}"
        for prefix in ('Y', 'L', 'S', 'L_stat'):
            os.makedirs(f"{out_dir}/{prefix}_{root}", exist_ok=True)

        for i in range(num_samples):
            Y = torch.from_numpy(skio.imread(f'{folder}/Y_{root}_{i}.tif').astype(float)).float()
            tau_3x3 = torch.from_numpy(scio.loadmat(f'{folder}/tau_{root}_{i}.mat')['tau'])
            print(f'info about Y_{i} = size: {Y.size()}, max: {torch.max(Y)}, min: {torch.min(Y)}, type: {Y.dtype}')
            print(f'info about tau_{i} = tau[0]: {tau_3x3[0]}, type: {tau_3x3.dtype}')

            Y_np, L_np, L_stat_np, S_np, tau_bear_comp_3x3, elapsed_time, _ = learn_tau_reals(
                Y.clone(), args, tau_3x3)

            # Montage: Y | L | L_stat | S placed side by side along the last axis.
            Y_L_Lstat_S_np = np.concatenate((Y_np, L_np, L_stat_np, S_np), axis=2)
            skio.imsave(f"{out_dir}/Y_{root}/Y_{root}_{i}.tif", Y_np)
            skio.imsave(f"{out_dir}/L_{root}/L_{root}_{i}.tif", L_np)
            skio.imsave(f"{out_dir}/S_{root}/S_{root}_{i}.tif", S_np)
            skio.imsave(f"{out_dir}/L_stat_{root}/L_stat_{root}_{i}.tif", L_stat_np)
            skio.imsave(f"{out_dir}/Y_L_Lstat_S_{root}_{i}.tif", Y_L_Lstat_S_np)

            tau_bear_comp = tau_bear_comp_3x3[:, 0:2, :]  # drop the homogeneous row
            scio.savemat(f'{out_dir}/tau_{i}.mat', {'tau': tau_bear_comp.numpy()})
            # Divided by 1000 to store seconds; presumes REALS reports milliseconds — confirm.
            scio.savemat(f'{out_dir}/time_{i}.mat', {'time': elapsed_time / 1000})

            ai = np.mean(calculate_tau_std(tau_bear_comp.numpy()))
            print(f'alignment inconsistency: {ai}')
if __name__ == "__main__":
    # Script entry point: hand the module-level namespace parsed from sys.argv to main.
    main(args=args)