-
Notifications
You must be signed in to change notification settings - Fork 0
/
ica_factorization.py
57 lines (46 loc) · 1.71 KB
/
ica_factorization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import time
import torch
from sklearn.decomposition import FastICA
import numpy as np
import random
def independent_components_decomposition(W, n_components, random_state=None):
    """Extract unit-norm independent directions from a weight matrix via FastICA.

    The rows of ``W`` are treated as samples and its columns as features, so
    every recovered component is a direction in ``W``'s column space.

    Args:
        W: 2-D weight matrix, either a ``torch.Tensor`` (on any device) or an
            array-like accepted by scikit-learn.
        n_components: Number of independent components to request from
            FastICA. ``None`` lets FastICA decide how many to use.
        random_state: Forwarded to ``FastICA``. With the default ``None``,
            FastICA draws from NumPy's global RNG, so results are reproducible
            only if the caller seeds ``np.random`` beforehand.

    Returns:
        A float32 CPU ``torch.Tensor`` of shape ``(n_features, k)``. Each
        column is one row of ``FastICA.components_`` scaled to unit L2 norm,
        and ``k`` is the number of components FastICA actually produced.
    """
    # scikit-learn works on host arrays only. A CUDA tensor, or one that
    # requires grad, cannot be converted implicitly by np.asarray.
    if isinstance(W, torch.Tensor):
        W = W.detach().cpu().numpy()

    fast_ica = FastICA(n_components=n_components, random_state=random_state)
    fast_ica.fit(W)
    components = fast_ica.components_  # one component per row

    # Scale each component (row) to unit length. keepdims gives a (k, 1)
    # column of norms that broadcasts for whatever k FastICA returned,
    # including when n_components is None.
    norms = np.linalg.norm(components, axis=1, keepdims=True)
    normalized = components / norms

    # Transpose so each column is one direction: shape (n_features, k).
    return torch.from_numpy(normalized.T).float()
if __name__ == "__main__":
    # Command-line entry point: load a checkpoint, stack its style-modulation
    # weights, run ICA on them and save the resulting directions.
    parser = argparse.ArgumentParser(
        description="Extract independent components (FastICA) of latent spaces "
        "from the style-modulation weights of a model checkpoint"
    )
    parser.add_argument(
        "-n",
        "--number_of_component",
        type=int,
        default=8,
        help="number of independent components to extract",
    )
    parser.add_argument(
        "--out", type=str, default="factor.pt", help="name of the result factor file"
    )
    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
    parser.add_argument(
        "--full_model",
        default=False,
        action="store_true",
        help="ckpt holds a full model object; use its state_dict() "
        "instead of ckpt['g_ema']",
    )
    args = parser.parse_args()

    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
    # machines. The ICA itself runs on host memory anyway.
    # NOTE(review): torch.load unpickles arbitrary objects, and --full_model
    # needs a full pickle. Only load checkpoints from trusted sources.
    checkpoint = torch.load(args.ckpt, map_location="cpu")
    if args.full_model:
        state_dict = checkpoint.state_dict()
    else:
        state_dict = checkpoint["g_ema"]

    # Collect every modulation weight except those of the to_rgbs layers, in
    # state-dict order, and stack them row-wise.
    weight_mat = [
        v
        for k, v in state_dict.items()
        if "modulation" in k and "to_rgbs" not in k and "weight" in k
    ]
    if not weight_mat:
        parser.error(
            f"no modulation weights (excluding to_rgbs) found in checkpoint {args.ckpt!r}"
        )
    W = torch.cat(weight_mat, 0)

    # FastICA is constructed without an explicit random_state, so it draws
    # from NumPy's global RNG. Seeding it here makes the output reproducible.
    np.random.seed(0)
    random.seed(0)

    start = time.process_time()
    eigvec = independent_components_decomposition(W, args.number_of_component).to("cpu")
    print(time.process_time() - start)  # CPU seconds spent in the decomposition
    torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)