/
weight_clustmap.py
executable file
·61 lines (48 loc) · 1.6 KB
/
weight_clustmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/python
# Load a weight matrix from a .npy file,
# cluster its rows and columns with seaborn's clustermap,
# and save the resulting heatmap as a PNG (matplotlib 'Agg' backend, no GUI).
import sys
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg') # for plotting without GUI
import matplotlib.pyplot as plt
import seaborn as sns
def random_subset_arr(arr, m_max, n_max):
    """Return a reproducible random sub-matrix of ``arr``.

    Up to ``m_max`` rows and ``n_max`` columns are drawn without replacement.
    Rows and columns of the result appear in the order they were drawn, not
    in their original order.

    Args:
        arr: 2-D numpy array to down-sample.
        m_max: maximum number of rows to keep.
        n_max: maximum number of columns to keep.

    Returns:
        A new array of shape ``(min(m, m_max), min(n, n_max))`` where
        ``(m, n)`` is ``arr.shape``.
    """
    m, n = arr.shape
    m_reduce = min(m, m_max)
    n_reduce = min(n, n_max)
    # Use a private generator with the fixed seed instead of seeding (and then
    # re-seeding) the global numpy RNG: the drawn indices are the same, but
    # the caller's global random state is left untouched.
    rng = np.random.RandomState(1201)
    row_rand_idx = rng.choice(m, m_reduce, replace=False)
    col_rand_idx = rng.choice(n, n_reduce, replace=False)
    # Select rows first, then columns, to get the full row x column cross
    # product rather than element-wise pairing of the two index arrays.
    arr_sub = arr[row_rand_idx][:, col_rand_idx]
    print('matrix from [{},{}] to a random subset of [{},{}]'.
          format(m, n, arr_sub.shape[0], arr_sub.shape[1]))
    return arr_sub
# read cmd: expects exactly two arguments, the input .npy path and an output tag
if len(sys.argv) != 3:
    print('\n\nusage: <weights_visualization.py> <w_name.npy> <out_tag>')
    print(sys.argv)
    raise Exception('cmd error')
in_name = sys.argv[1]
tag = sys.argv[2]
print('usage:', sys.argv)

# read data
arr = np.load(in_name)

# exclude saved bias files
# Check the rank before unpacking the shape: a 1-D array (or any non-2-D
# array) would otherwise fail with an opaque unpacking ValueError instead of
# the intended "skipped" error.
if arr.ndim != 2 or 1 in arr.shape:
    raise Exception('Not matrix, but vector, so skipped')
m, n = arr.shape
print('matrix sample', arr[0:2, 0:2])
print('matrix shape:', arr.shape)

# exclude large matrix: cap each dimension so clustering stays tractable
m_max = 1000
n_max = 1000
if m > m_max or n > n_max:
    print('matrix too large, down-sample to 1000 max each dim')
    arr = random_subset_arr(arr, m_max, n_max)

# seaborn clustering (the rows are rows, columns are columns in clustmap)
heatmap = sns.clustermap(arr, method='average', cmap="summer", robust=True)
# output is written next to the input as <in_name>.<tag>.png
heatmap.savefig(in_name + '.' + tag + '.png', bbox_inches='tight')