In [3]:
import torch 
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader, Data
import h5py
import os
import hdf5storage

os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"

train_DS = []
test_DS = []
data_path = ''

# Real and test dataset split
test_nos = [9] # Use Piccadilly as the test dataset
train_nos = [i for i in range(1,20) if i not in [9, 17, 18]]
# Don't include Piccadilly (9), ArtsQuad(17) or SanFran (18) in Train dataset
print(test_nos, train_nos)

[9] [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 19]


In [2]:
'''Load Test Data'''
#no_measurements = 17 
#filename = data_path+'data/gt_graph_random_large_outliers_real.h5'
filename = data_path+'data/gt_real_cleaned_exAQSF.h5'  
for item in test_nos:#range(no_measurements): 
    x = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/x', filename=filename, options=None), dtype=torch.float)
    xt = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/xt', filename=filename, options=None), dtype=torch.float)
    o = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/o', filename=filename, options=None), dtype=torch.float)
    edge_index = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/edge_index', filename=filename, options=None), dtype=torch.long)
    edge_attr = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/edge_feature', filename=filename, options=None), dtype=torch.float)
    y = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/y', filename=filename, options=None), dtype=torch.float)
    test_DS.append(Data(x=x, xt=xt, o=o, y=y, edge_index=edge_index.t().contiguous(), edge_attr=edge_attr)) 


In [3]:
'''Load Train Data'''
#no_measurements = 1200 
#filename = data_path+'data/gt_graph_random_large_outliers.h5'  
for item in train_nos:#range(no_measurements): 
    x = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/x', filename=filename, options=None), dtype=torch.float)
    xt = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/xt', filename=filename, options=None), dtype=torch.float)
    o = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/o', filename=filename, options=None), dtype=torch.float)
    edge_index = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/edge_index', filename=filename, options=None), dtype=torch.long)
    edge_attr = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/edge_feature', filename=filename, options=None), dtype=torch.float)
    y = torch.tensor(hdf5storage.read(path='/data/'+str(item)+'/y', filename=filename, options=None), dtype=torch.float)
    train_DS.append(Data(x=x, xt=xt, o=o, y=y, edge_index=edge_index.t().contiguous(), edge_attr=edge_attr)) 


In [4]:
def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
    Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Compute outer product
    terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))

    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(original_shape)

def inv_q(q):
    """
    Inverse quaternion(s) q .
    """
    assert q.shape[-1] == 4
    original_shape = q.shape
    return torch.stack((q[:, 0], -q[:, 1], -q[:, 2], -q[:, 3]), dim=1).view(original_shape)

In [5]:
from torch.nn import Sequential as Seq, Linear, ReLU, BatchNorm1d as BN, Dropout
from torch_geometric.nn import MessagePassing

class EdgeConvRot(MessagePassing):
    def __init__(self, in_channels, edge_channels, out_channels):
        super(EdgeConvRot, self).__init__(aggr='mean', flow="target_to_source") #  "Max" aggregation.
        self.mlp0 = Seq(Linear(edge_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

        self.mlp = Seq(Linear(2*in_channels+edge_channels, out_channels),
               ReLU(),
               Linear(out_channels, out_channels))
            
    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr): 
        if x_i.size(1) > 5: 
            W = torch.cat([torch.cat([x_i, x_j], dim=1), edge_attr], dim=1)  # tmp has shape [E, 2 * in_channels]
            W = self.mlp(W) 
        else:
            W = edge_attr # torch.cat([torch.cat([x_i, x_j], dim=1), edge_attr], dim=1)  # tmp has shape [E, 2 * in_channels]            
            W = self.mlp0(W) 
        return W
            
    def propagate(self, edge_index, size, x, edge_attr):    
        row, col = edge_index
        x_i = x[row]
        x_j = x[col]
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) 
        edge_out = self.message(x_i, x_j, edge_attr)
        out = scatter_(self.aggr, edge_out, edge_index[i], dim_size=size[i])
        return out, edge_out 

    


In [6]:
import torch
import torch.nn.functional as F
from torch_geometric.utils import scatter_
from torch_geometric.nn import GATConv
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

def node_model(x, batch):
   # print(batch.shape)
    out, inverse_indices = torch.unique_consecutive(batch, return_inverse=True)
    quat_vals = x[inverse_indices] 
    q_ij = qmul(x, inv_q(quat_vals[batch]))  
    return q_ij 

def edge_model(x, edge_index):
    row, col = edge_index
    q_ij = qmul(x[col], inv_q(x[row]))  
    return q_ij 

def MLP(channels, batch_norm=True):
    return Seq(*[
        Seq(Linear(channels[i - 1], channels[i]), ReLU())
        for i in range(1, len(channels))
    ])
class EdgePred(torch.nn.Module):
    def __init__(self, in_channels, edge_channels):
        super(EdgePred, self).__init__()
        self.mlp = Seq(Linear(2*in_channels+edge_channels, 8),
                       ReLU(),
                       Linear(8, 1)) 
    def forward(self, xn, edge_index, edge_attr): 
        row, col = edge_index
        xn = torch.cat([xn[row], xn[col], edge_attr], dim=1)
        xn = self.mlp(xn) 
        return torch.sigmoid(xn) 
    
class GlobalSAModule(torch.nn.Module):
    def __init__(self, nn1, nn2):
        super(GlobalSAModule, self).__init__()
        self.nn1 = nn1
        self.nn2 = nn2

    def forward(self, x, batch): 
        xn = self.nn1(x)
      #  xn = F._max_pool1d(xn, x.size(1))
       # xn = scatter_('mean', xn, batch)
       # xn = xn[batch]  
        xn = torch.cat([xn, x], dim=1) 
     #   print(xn.shape)
      #  x = xn.unsqueeze(0).repeat(x.size(0), 1, 1) 
      #  batch = torch.arange(x.size(0), device=batch.device)
        return self.nn2(xn)
    
 

In [7]:
def update_attr(x, edge_index, edge_attr):
    row, col = edge_index
    x_i = x[row]
    x_j = inv_q(x[col])
    W=qmul(edge_attr, x_i) 
    W=qmul(x_j, W) 
    return W 


def smooth_l1_loss(input, beta=1. / 5, size_average=False):
    """
    very similar to the smooth_l1_loss from pytorch, but with
    the extra beta parameter
    """
    n = torch.abs(input)
       
    cond = n < beta
    loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    if size_average:
        return loss.mean()
    return loss.sum()

def my_smooth_l1_loss(input, beta, alpha=0.05):
    """
    very similar to the smooth_l1_loss from pytorch, but with
    the extra beta parameter
    """
    nn = torch.sum(input ** 2, dim=1) 
    beta = torch.squeeze(beta) 
    nn = torch.mul(nn, beta) 

    return nn.sum()

class Net(torch.nn.Module):
    def __init__(self): 
        super(Net, self).__init__() 
        self.no_features = 32   # More features for large dataset 
        self.conv1 = EdgeConvRot(4, 4, self.no_features) 
        self.conv2 = EdgeConvRot(self.no_features, self.no_features+4, self.no_features)  
        self.conv3 = EdgeConvRot(2*self.no_features, 2*self.no_features, self.no_features) 
        self.conv4 = EdgeConvRot(2*self.no_features, 2*self.no_features, self.no_features) 

        self.lin01 = Linear(self.no_features, self.no_features) 
        self.lin02 = Linear(self.no_features, self.no_features) 
        self.lin1 = Linear(self.no_features, 4) 
        self.lin2 = Linear(self.no_features, 1) 
        
        self.m = torch.nn.Sigmoid() 
    def forward(self, data):
        x_org, edge_index, edge_attr, batch, beta = data.x, data.edge_index, data.edge_attr, data.batch, data.o  
        
        x1, edge_x1 = self.conv1(torch.zeros_like(x_org), edge_index, edge_attr)
        x1 = F.relu(x1)
        edge_x1 = F.relu(edge_x1)
        
        x2, edge_x2 = self.conv2(x1, edge_index, torch.cat([edge_attr, edge_x1], dim=1))
        x2 = F.relu(x2)
        edge_x2 = F.relu(edge_x2)

        x3, edge_x3 = self.conv3(torch.cat([x2, x1], dim=1), edge_index, torch.cat([edge_x2, edge_x1], dim=1))
        x3 = F.relu(x3)
        edge_x3 = F.relu(edge_x3)
        
        x4, edge_x4 = self.conv4(torch.cat([x3, x2], dim=1), edge_index, torch.cat([edge_x3, edge_x2], dim=1))
        edge_x4 = F.relu(edge_x4)
        
        out01 = self.lin01(edge_x4) 
        out02 = self.lin02(edge_x4) 
        
      #  print(out.shape) 
        edge_x = self.lin1(out01) + edge_attr
        edge_x = F.normalize(edge_x, p=2, dim=1) 
        
        return self.m(self.lin2(out02)), edge_x, beta   #x, loss1, beta   # node_model(x, batch),

In [8]:
#PATH = 'checkpoint/outliers_detect_new.pth' 
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device) 
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
#checkpoint = torch.load(PATH)
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
for g in optimizer.param_groups:
    g['lr'] = 0.0001

In [9]:
PATH = 'checkpoint/cleannet_9.pth' 
import numpy as np 
no_training = 16 #round(len(datasetR)*training_exmpl)
#no_testing = 100 
#print(no_training) 
test_loader  = DataLoader(test_DS, batch_size=1, shuffle=True,num_workers=2)
train_loader = DataLoader(train_DS, batch_size=4, shuffle=False,num_workers=2)

criterion = torch.nn.BCELoss()

model.train()
best_loss = 2000 
t = time.time() 
count = 0 
val = 14688
for epoch in range(2500):
    total_loss1 = 0 
    total_loss2 = 0 
    theta = []
    loss = 0 
    for idx, data in enumerate(train_loader):
        data_gpu = data.to(device)
        optimizer.zero_grad() 

   #     print(data_gpu.y.shape)

        out, edge_x, beta = model(data_gpu)
    
        loss1 = qmul(edge_x, inv_q(edge_model(data_gpu.y, data_gpu.edge_index)))  
        loss1 = smooth_l1_loss(loss1[:, 1:])
        
        loss2 = criterion(out, beta) 
        loss = 0.1*loss1 + 500*loss2
     #   if idx % 2 == 0: 
        loss.backward()
        optimizer.step()
       # print([idx, loss.item()])
       # time.sleep(0.01)
        
        if epoch % 2 == 0: 
           # loss1 = qmul(data_gpu.edge_attr, inv_q(edge_model(data_gpu.y, data_gpu.edge_index)))  
           # loss1 = smooth_l1_loss(loss1[:, 1:]) 
            total_loss1 = total_loss1 + loss1.item() 
            total_loss2 = total_loss2 + loss2.item() 
            
    if epoch % 10 == 0:
        for data in test_loader: 
            data_gpu = data.to(device)
            out, edge_x, beta = model(data_gpu)
            loss1 = qmul(edge_x, inv_q(edge_model(data_gpu.y, data_gpu.edge_index)))  
            loss1 = smooth_l1_loss(loss1[:, 1:])
            loss2 = criterion(out, beta) 
            total_loss1 = total_loss1 + loss1.item() 
            total_loss2 = total_loss2 + loss2.item()
            
        count = count + 1
    if epoch % 2 == 0: 
        print([epoch, "{0:.5f}".format(total_loss1/no_training), "{0:.5f}".format(total_loss2/no_training), "{0:.3f}".format(time.time() - t)])
        if epoch % 10 == 0:
            if val > total_loss1/no_training : 
                val = total_loss1/no_training 
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, PATH) 


[0, '18157.08716', '0.21652', '1.382']
[2, '12347.50989', '0.17247', '3.212']
[4, '11780.57544', '0.17173', '5.012']
[6, '11321.65442', '0.17103', '6.867']
[8, '10950.74097', '0.17033', '8.674']
[10, '15287.24207', '0.21238', '10.794']
[12, '10414.02124', '0.16898', '12.627']
[14, '10222.12329', '0.16831', '14.427']
[16, '10066.55115', '0.16763', '16.220']
[18, '9935.96533', '0.16699', '18.051']
[20, '14267.34375', '0.20858', '20.152']
[22, '9718.85938', '0.16573', '21.987']
[24, '9625.72656', '0.16510', '23.808']
[26, '9539.66553', '0.16445', '25.624']
[28, '9460.70056', '0.16379', '27.446']
[30, '13696.33643', '0.20483', '29.556']
[32, '9326.54382', '0.16237', '31.384']
[34, '9276.59753', '0.16160', '33.215']
[36, '9240.01685', '0.16068', '35.013']
[38, '9208.98828', '0.15971', '36.872']
[40, '13398.09058', '0.19976', '38.979']
[42, '9168.32007', '0.15753', '40.803']
[44, '9157.44080', '0.15635', '42.635']
[46, '9151.89844', '0.15522', '44.482']
[48, '9149.63599', '0.15407', '46.302'

[394, '6081.39832', '0.11394', '375.340']
[396, '6077.40503', '0.11393', '377.217']
[398, '6073.20599', '0.11392', '379.060']
[400, '9153.69067', '0.15436', '381.194']
[402, '6065.52808', '0.11391', '383.023']
[404, '6061.85095', '0.11390', '384.895']
[406, '6058.20490', '0.11389', '386.745']
[408, '6054.66638', '0.11388', '388.581']
[410, '9123.31458', '0.15436', '390.671']
[412, '6047.88403', '0.11385', '392.530']
[414, '6044.59619', '0.11384', '394.382']
[416, '6041.38025', '0.11383', '396.230']
[418, '6038.23395', '0.11382', '398.063']
[420, '9094.58856', '0.15428', '400.171']
[422, '6032.10065', '0.11379', '402.025']
[424, '6029.10193', '0.11378', '403.849']
[426, '6026.14386', '0.11376', '405.672']
[428, '6023.23676', '0.11375', '407.513']
[430, '9068.28601', '0.15416', '409.646']
[432, '6018.03577', '0.11372', '411.480']
[434, '6017.08777', '0.11370', '413.293']
[436, '6021.87024', '0.11369', '415.140']
[438, '6047.84521', '0.11368', '416.957']
[440, '9183.60510', '0.15403', '41

[786, '5594.72040', '0.10953', '745.932']
[788, '5591.01929', '0.10950', '747.773']
[790, '8378.10217', '0.14791', '749.896']
[792, '5583.48291', '0.10943', '751.746']
[794, '5579.63696', '0.10939', '753.597']
[796, '5575.73413', '0.10936', '755.414']
[798, '5571.77295', '0.10932', '757.253']
[800, '8351.46826', '0.14765', '759.360']
[802, '5563.68518', '0.10925', '761.181']
[804, '5559.54962', '0.10921', '762.996']
[806, '5555.35565', '0.10918', '764.858']
[808, '5551.10114', '0.10914', '766.699']
[810, '8323.99396', '0.14740', '768.882']
[812, '5542.76434', '0.10906', '770.708']
[814, '5538.97504', '0.10903', '772.556']
[816, '5536.24072', '0.10900', '774.390']
[818, '5536.42285', '0.10896', '776.230']
[820, '8317.98096', '0.14714', '778.326']
[822, '5561.75220', '0.10892', '780.161']
[824, '5591.09113', '0.10891', '781.966']
[826, '5606.13971', '0.10891', '783.806']
[828, '5571.18994', '0.10890', '785.621']
[830, '8278.62402', '0.14690', '787.748']
[832, '5498.83374', '0.10876', '78

[1170, '7117.96674', '0.13434', '1111.313']
[1172, '4665.35071', '0.09978', '1113.155']
[1174, '4662.69647', '0.09971', '1114.980']
[1176, '4660.67712', '0.09963', '1116.816']
[1178, '4660.38525', '0.09955', '1118.678']
[1180, '7108.56903', '0.13383', '1120.803']
[1182, '4681.94910', '0.09937', '1122.654']
[1184, '4721.51849', '0.09926', '1124.511']
[1186, '4756.40381', '0.09918', '1126.347']
[1188, '4712.84625', '0.09918', '1128.190']
[1190, '7096.63477', '0.13335', '1130.317']
[1192, '4642.57239', '0.09920', '1132.167']
[1194, '4636.02783', '0.09913', '1133.997']
[1196, '4631.86420', '0.09903', '1135.857']
[1198, '4629.99896', '0.09895', '1137.708']
[1200, '7060.49878', '0.13300', '1139.841']
[1202, '4623.97369', '0.09884', '1141.704']
[1204, '4620.92773', '0.09877', '1143.556']
[1206, '4618.13705', '0.09870', '1145.397']
[1208, '4615.45294', '0.09863', '1147.233']
[1210, '7040.46698', '0.13254', '1149.350']
[1212, '4610.20245', '0.09850', '1151.175']
[1214, '4607.60385', '0.09843', 

[1544, '4333.96539', '0.08976', '1466.658']
[1546, '4341.88446', '0.08971', '1468.507']
[1548, '4366.74985', '0.08967', '1470.352']
[1550, '6759.73239', '0.12078', '1472.505']
[1552, '4526.59344', '0.08979', '1474.351']
[1554, '4504.84796', '0.09001', '1476.226']
[1556, '4368.25458', '0.09013', '1478.064']
[1558, '4348.39639', '0.09002', '1479.933']
[1560, '6664.70132', '0.12078', '1482.066']
[1562, '4332.06567', '0.08966', '1483.892']
[1564, '4329.70648', '0.08957', '1485.752']
[1566, '4328.01675', '0.08953', '1487.605']
[1568, '4324.73956', '0.08946', '1489.423']
[1570, '6636.24353', '0.12027', '1491.523']
[1572, '4323.33774', '0.08933', '1493.352']
[1574, '4324.79105', '0.08929', '1495.218']
[1576, '4328.19336', '0.08928', '1497.033']
[1578, '4333.82550', '0.08928', '1498.848']
[1580, '6654.14264', '0.12016', '1500.967']
[1582, '4349.64059', '0.08932', '1502.819']
[1584, '4354.34070', '0.08935', '1504.648']
[1586, '4350.77142', '0.08935', '1506.492']
[1588, '4338.19662', '0.08929', 

[1918, '4097.66217', '0.08272', '1821.904']
[1920, '6336.03931', '0.11197', '1824.043']
[1922, '4094.81543', '0.08261', '1825.854']
[1924, '4093.42151', '0.08258', '1827.720']
[1926, '4092.03760', '0.08254', '1829.585']
[1928, '4090.62830', '0.08250', '1831.433']
[1930, '6327.34058', '0.11173', '1833.552']
[1932, '4087.78653', '0.08242', '1835.384']
[1934, '4086.38885', '0.08238', '1837.231']
[1936, '4085.04611', '0.08234', '1839.089']
[1938, '4083.86438', '0.08230', '1840.910']
[1940, '6318.64798', '0.11150', '1843.028']
[1942, '4084.68353', '0.08221', '1844.851']
[1944, '4092.78476', '0.08219', '1846.685']
[1946, '4124.04761', '0.08223', '1848.534']
[1948, '4223.92859', '0.08252', '1850.397']
[1950, '6660.46893', '0.11349', '1852.536']
[1952, '4392.13159', '0.08451', '1854.383']
[1954, '4098.08643', '0.08432', '1856.246']
[1956, '4109.23761', '0.08333', '1858.061']
[1958, '4100.67017', '0.08281', '1859.865']
[1960, '6296.49298', '0.11188', '1861.966']
[1962, '4070.42413', '0.08236', 

[2292, '3824.50308', '0.07753', '2177.564']
[2294, '3824.08072', '0.07743', '2179.427']
[2296, '3822.31702', '0.07736', '2181.274']
[2298, '3820.61517', '0.07727', '2183.078']
[2300, '5943.73663', '0.10577', '2185.218']
[2302, '3818.18454', '0.07718', '2187.094']
[2304, '3818.03561', '0.07713', '2188.957']
[2306, '3819.06924', '0.07711', '2190.816']
[2308, '3822.41046', '0.07710', '2192.637']
[2310, '5953.80347', '0.10569', '2194.767']
[2312, '3845.07147', '0.07719', '2196.610']
[2314, '3869.63754', '0.07733', '2198.466']
[2316, '3897.44684', '0.07756', '2200.299']
[2318, '3902.79352', '0.07778', '2202.158']
[2320, '5986.39685', '0.10651', '2204.270']
[2322, '3827.01120', '0.07752', '2206.122']
[2324, '3808.34183', '0.07723', '2207.974']
[2326, '3803.80505', '0.07705', '2209.834']
[2328, '3802.59106', '0.07696', '2211.689']
[2330, '5921.54950', '0.10539', '2213.821']
[2332, '3801.59735', '0.07686', '2215.662']
[2334, '3801.75211', '0.07681', '2217.505']
[2336, '3802.78070', '0.07679', 

In [10]:
import numpy as np 
import math 
import h5py
import time 
criterion = torch.nn.BCELoss()

data_path = './' # os.getcwd() 
PATH = 'checkpoint/cleannet_9.pth'
#PATH = 'checkpoint/outliers_detect_new.pth' 

#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = Net().to(device) 
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])

train_loader = DataLoader(train_DS, batch_size=1, shuffle=False)
test_loader = DataLoader(test_DS, batch_size=1, shuffle=False)
#test_loader = DataLoader(datasetR, batch_size=1, shuffle=False)
#model = best_model 
#print(best_loss)
#pred_rot = []
model.eval()
total_loss = 0 
count = 0 
total_time = 0
t = time.time() 
hf = h5py.File(data_path+'data/cleannet_train_9.h5', 'w')
#hf = h5py.File(data_path+'data/gt_graph_random_large_outliers_test_pred_rot.h5', 'w')
theta = [] 

for data in train_loader: 
    print(data) 
    data_gpu = data.to(device)
    out, x, beta = model(data_gpu)
  #  x = edge_model(data_gpu.xt, data_gpu.edge_index) 
   # loss = criterion(out, beta)
  #  loss = (pred - data_gpu.y).pow(2).sum() 
  #  total_loss = total_loss + loss.item() 
 #   pred_rot = torch.cat([data.xt, pred, data.y], dim=1).data.cpu().numpy()
    hf.create_dataset('/data/'+str(count+1)+'/ot', data=out.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/o', data=beta.data.cpu().numpy())
  #  hf.create_dataset('/data/'+str(count+1)+'/onode', data=data_gpu.onode.data.cpu().numpy())
  #  hf.create_dataset('/data/'+str(count+1)+'/omarker', data=data_gpu.omarker.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/refined_qq', data=x.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/y', data=data_gpu.y.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/xt', data=data_gpu.xt.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/edge_index', data=data_gpu.edge_index.data.cpu().numpy())
    hf.create_dataset('/data/'+str(count+1)+'/edge_feature', data=data_gpu.edge_attr.data.cpu().numpy())
    count = count + 1 
   # print([len(pred_rot), (time.time()-t)/len(pred_rot)])
#print([total_loss/len(test_loader), (time.time() - t)/(test_loader.batch_size*len(test_loader))])
hf.close() 
   # print([len(pred_rot), (time.time()-t)/len(pred_rot)])
print([total_loss/len(train_loader), total_time /(test_loader.batch_size*len(train_loader))]) 


Batch(batch=[577], edge_attr=[194030, 4], edge_index=[2, 194030], o=[194030, 1], x=[577, 4], xt=[577, 4], y=[577, 4])
Batch(batch=[227], edge_attr=[40020, 4], edge_index=[2, 40020], o=[40020, 1], x=[227, 4], xt=[227, 4], y=[227, 4])
Batch(batch=[677], edge_attr=[95510, 4], edge_index=[2, 95510], o=[95510, 1], x=[677, 4], xt=[677, 4], y=[677, 4])
Batch(batch=[341], edge_attr=[47188, 4], edge_index=[2, 47188], o=[47188, 1], x=[341, 4], xt=[341, 4], y=[341, 4])
Batch(batch=[450], edge_attr=[104680, 4], edge_index=[2, 104680], o=[104680, 1], x=[450, 4], xt=[450, 4], y=[450, 4])
Batch(batch=[332], edge_attr=[41138, 4], edge_index=[2, 41138], o=[41138, 1], x=[332, 4], xt=[332, 4], y=[332, 4])
Batch(batch=[553], edge_attr=[207864, 4], edge_index=[2, 207864], o=[207864, 1], x=[553, 4], xt=[553, 4], y=[553, 4])
Batch(batch=[338], edge_attr=[49352, 4], edge_index=[2, 49352], o=[49352, 1], x=[338, 4], xt=[338, 4], y=[338, 4])
Batch(batch=[1084], edge_attr=[140162, 4], edge_index=[2, 140162], o=[1

In [11]:
hf.close() 
