In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import h5py
import numpy as np

In [12]:
# Load the experimental data; each data[i] is the distance-field volume of one streamline.
# The `with` block closes the HDF5 file as soon as the array has been read into memory.
with h5py.File('kitchen.hdf5', 'r') as f:
    data = np.array(f['data_mat'])
# Switch the data from a signed distance field to an unsigned one (|d|), in place.
np.abs(data, out=data)
# Sanity check: no negative values remain; also show the overall shape of the data.
print(data[data < 0].shape)
print(data[data >= 0].shape)
print(data.shape)

(0,)
(68386816,)
(2087, 32, 32, 32)


In [None]:
# Load the trained autoencoder: choose a device, define the SAE architecture, then restore its weights from 'UDF-AE.pkl'.

In [3]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
class SAE(nn.Module):
    """3-D convolutional autoencoder for single-channel 32x32x32 volumes.

    Encoder: three bias-free Conv3d layers (kernel 2, stride 2), each halving every
    spatial dimension, with channels 1 -> 8 -> 32 -> 64.
    Decoder: three ConvTranspose3d layers (kernel 2, stride 2, with bias), each
    doubling every spatial dimension, with channels 64 -> 32 -> 8 -> 1.
    Every layer is followed by a ReLU, so both outputs are non-negative.

    The attribute names conv1..conv6 are the keys stored in the trained checkpoint,
    so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        down = dict(kernel_size=2, stride=2, padding=0, bias=False)
        up = dict(kernel_size=2, stride=2, padding=0, bias=True)
        # Encoder. Layers are created in this exact order so random initialisation
        # consumes the RNG identically.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=8, **down)
        self.conv2 = nn.Conv3d(in_channels=8, out_channels=32, **down)
        self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, **down)
        # Decoder.
        self.conv4 = nn.ConvTranspose3d(in_channels=64, out_channels=32, **up)
        self.conv5 = nn.ConvTranspose3d(in_channels=32, out_channels=8, **up)
        self.conv6 = nn.ConvTranspose3d(in_channels=8, out_channels=1, **up)

    def forward(self, x):
        """Encode and then decode ``x``.

        Returns ``(latent, reconstruction)``. For an input of shape (N, 1, 32, 32, 32)
        the latent has shape (N, 64, 4, 4, 4) and the reconstruction (N, 1, 32, 32, 32).
        """
        latent = x
        for layer in (self.conv1, self.conv2, self.conv3):
            latent = F.relu(layer(latent))
        reconstruction = latent
        for layer in (self.conv4, self.conv5, self.conv6):
            reconstruction = F.relu(layer(reconstruction))
        return latent, reconstruction

In [5]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator
# (nothing to release when CUDA is unavailable).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [6]:
# Build the autoencoder on the selected device and restore the trained weights.
# - map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# - strict=True turns a key mismatch into an error instead of silently leaving
#   layers randomly initialised (this checkpoint matches all keys).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
net = SAE().to(device)
net.eval()  # inference only; SAE has no dropout/batch-norm, so this is a safeguard
net.load_state_dict(torch.load('UDF-AE.pkl', map_location=device), strict=True)

<All keys matched successfully>

In [13]:
# Choose which streamlines to keep after compression.
#
# Starting from a baseline streamline, each step keeps the not-yet-chosen streamline
# whose SAE latent code is farthest, in squared Euclidean distance, from the most
# recently chosen one. Each step depends only on the previous choices. So the first
# 8 / 16 / 32 entries of a run with N_KEEP = 64 are exactly the answers for
# 8 / 16 / 32 streamlines, and one run covers all of these cases.
#
# The latent codes never change between steps, so every streamline is encoded once
# (batched, without autograd). The old loop re-ran the encoder on the whole dataset
# in every one of the N_KEEP - 1 steps.

BASELINE_INDEX = 50  # serial number of the baseline streamline (picked arbitrarily)
N_KEEP = 64          # number of streamlines to retain after compression


def encode_all(model, volumes, device, batch_size=256):
    """Return the flattened latent codes of all volumes, shape (len(volumes), D).

    Each volumes[i] is cast to float32 and reshaped to (1, 32, 32, 32) before encoding.
    """
    codes = []
    with torch.no_grad():  # inference only: do not build the autograd graph
        for start in range(0, len(volumes), batch_size):
            batch = np.asarray(volumes[start:start + batch_size], dtype=np.float32)
            inputs = torch.from_numpy(batch).to(device).view(-1, 1, 32, 32, 32)
            latent, _ = model(inputs)
            codes.append(latent.cpu().numpy().reshape(len(batch), -1))
    if not codes:
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(codes)


def select_streamlines(codes, start, count):
    """Greedily pick up to `count` row indices of `codes`, beginning with `start`.

    After `start`, each pick is the unpicked row with the largest squared Euclidean
    distance to the previous pick. Ties go to the highest index, matching the
    original sort-then-scan-from-the-end loop. Stops early once every row is picked.
    Returns a list of Python ints.
    """
    chosen = [start]
    taken = np.zeros(len(codes), dtype=bool)
    taken[start] = True
    for _ in range(count - 1):
        candidates = np.flatnonzero(~taken)
        if candidates.size == 0:
            break
        dist = np.sum((codes[candidates] - codes[chosen[-1]]) ** 2, axis=1)
        # argmax returns the first maximum; search the reversed candidates so that
        # ties resolve to the highest index.
        best = int(candidates[::-1][np.argmax(dist[::-1])])
        chosen.append(best)
        taken[best] = True
    return chosen


streamline_codes = encode_all(net, data, device)
out_data = select_streamlines(streamline_codes, BASELINE_INDEX, N_KEEP)
a = out_data[-1]  # last chosen streamline (same value the original loop left in `a`)

In [14]:
# Serial numbers of the streamlines to retain when compressing down to 8 streamlines.
print(out_data[:8])

[50, 1785, 1591, 337, 972, 236, 64, 1444]


In [15]:
print(out_data[:16])  # streamlines to retain when compressing down to 16

[50, 1785, 1591, 337, 972, 236, 64, 1444, 1524, 1575, 1443, 1772, 806, 778, 983, 571]


In [16]:
print(out_data[:32])  # streamlines to retain when compressing down to 32

[50, 1785, 1591, 337, 972, 236, 64, 1444, 1524, 1575, 1443, 1772, 806, 778, 983, 571, 817, 673, 639, 917, 650, 698, 1793, 567, 1538, 402, 1970, 1775, 1990, 1878, 1988, 1517]


In [17]:
print(out_data[:])  # every selected streamline (the full 64-streamline answer)

[50, 1785, 1591, 337, 972, 236, 64, 1444, 1524, 1575, 1443, 1772, 806, 778, 983, 571, 817, 673, 639, 917, 650, 698, 1793, 567, 1538, 402, 1970, 1775, 1990, 1878, 1988, 1517, 1983, 1547, 1979, 1879, 1961, 1676, 1910, 1658, 1551, 391, 1523, 721, 741, 542, 1975, 1642, 2001, 1540, 1996, 1387, 1995, 945, 1994, 1633, 1968, 1646, 1966, 1660, 1993, 985, 1989, 1830]


In [None]:
# Save the serial numbers of the first N_SAVE selected streamlines.
# Set N_SAVE to 8, 16, 32 or 64; the prefix property of the greedy selection makes
# every such prefix a valid answer, and the file name follows N_SAVE.
N_SAVE = 64
# 0/1 mask over all streamlines (same length as the experimental data) marking the
# retained ones.
out_list = np.zeros(len(data))
out_list[out_data[:N_SAVE]] = 1
np.save(f'kitchen_compression_{N_SAVE}.npy', out_data[:N_SAVE])