-
Notifications
You must be signed in to change notification settings - Fork 0
/
tmp.py
58 lines (31 loc) · 741 Bytes
/
tmp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import time
# Scratch experiment on a random column vector ``b``:
#   1. build a pairwise-difference matrix ``c`` between entries of ``b``;
#   2. mark "close" pairs where the difference is <= THRESHOLD;
#   3. pick the row whose upper triangle (incl. diagonal) has the most close
#      pairs;
#   4. print the one-step-shifted values of ``b`` selected by that row.
# Every intermediate tensor is printed for inspection.

THRESHOLD = 0.01  # pair (r, k) is "close" when b[k] - b[r] <= THRESHOLD
NUM_VALUES = 10  # length of the random column vector b
ROW_TRIM = 3  # trailing entries of b that do not get a row in c

b = torch.rand(NUM_VALUES, 1)  # shape (NUM_VALUES, 1), uniform in [0, 1)
len_b = len(b)

# c[r, k] = b[k] - b[r] for r < len_b - ROW_TRIM and k < len_b - 1.
# Broadcasting (1, len_b - 1) against (len_b - ROW_TRIM, 1) gives a 2-D
# matrix; with the defaults its shape is (7, 9).
c = torch.sub(b.transpose(1, 0)[:, :len_b - 1], b[:len_b - ROW_TRIM, :])
print(c.shape)

# Row vector of b[1..len_b-1]: the value one step *after* each column of c.
# NOTE(review): column k of c compares b[k], while column k of e/f holds
# b[k + 1]. Confirm this one-step shift is intended.
b = b.transpose(1, 0)[:, 1:len_b]
# One copy of the shifted row per row of c. ``repeat`` builds all copies in a
# single allocation (no repeated torch.cat copying) and derives the row count
# from c, so e always lines up with c for torch.where below.
e = b.repeat(c.shape[0], 1)
print(e.shape)
print(e)
print(b)

close = c <= THRESHOLD
d = torch.where(close, 1, 0)  # integer mask of close pairs
# Shifted value where the pair is close, 0 elsewhere. ``e`` is already a
# float tensor, so no legacy ``torch.FloatTensor`` conversion is needed.
f = torch.where(close, e, torch.zeros_like(e))
print(f)
print(d)

# Compute each upper-triangular view once and reuse it.
upper_d = torch.triu(d)
upper_f = torch.triu(f)
print(upper_d)
close_counts = upper_d.sum(1)  # close pairs per row, upper triangle only
print(close_counts)
# Row with the most close pairs (torch.argmax returns the first on ties).
max_idx = torch.argmax(close_counts)
print(max_idx)
print(upper_f)

g = upper_f[max_idx]  # selected row, 1-D of length len_b - 1
print(g)
h = torch.nonzero(g)  # indices of non-zero entries of g, shape (k, 1)
print(h)
i = g[h]  # the selected non-zero values, shape (k, 1)
print(i[:, 0])  # the same values as a 1-D tensor