This script will use the 32 samples generated by get_samples.ipynb to implement DataLoader.

In [96]:
import numpy as np
import sys
import h5py
import os
from __future__ import print_function, division
import torch
import pandas as pd
import numpy as np

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
%matplotlib inline

from torch.utils.data import Dataset, DataLoader

### Getting the file size of the samples

In [7]:
os.path.getsize('../data/sample_32.h5') / 1e6 # in MBs

268.437504

### Reading the data

In [39]:
f = h5py.File('../data/sample_32.h5', 'r')
list(f.keys())

['sample32']

In [40]:
delta_HI = f['sample32']
f['sample32'].shape

(32, 128, 128, 128)

In [28]:
f.close()

In [51]:
delta_HI.shape

(32, 128, 128, 128)

In [49]:
delta_HI[0].shape

(128, 128, 128)

In [47]:
for sample in delta_HI:
    print(sample)

[[[2.7504480e-03 1.0773394e-02 6.7715812e-03 ... 1.3569459e+00
   1.0198050e+00 8.8540632e-01]
  [7.3223263e-03 7.8354850e-03 4.2035379e-02 ... 1.2797499e+00
   1.5884297e+00 1.7761337e+00]
  [6.9693945e-02 5.6814358e-02 2.5016809e-01 ... 3.3856087e+00
   3.5277927e+00 2.4171164e+00]
  ...
  [4.0532215e+01 3.8151985e+01 2.3695166e+01 ... 0.0000000e+00
   0.0000000e+00 5.9972785e-02]
  [2.3356449e+01 2.7734932e+01 5.1065240e+00 ... 0.0000000e+00
   0.0000000e+00 1.6837816e-01]
  [2.1791283e+01 2.1270782e+01 7.6947075e-01 ... 7.7227759e-04
   0.0000000e+00 0.0000000e+00]]

 [[1.6299952e-02 1.1533336e-02 1.6717875e-02 ... 2.3383579e+00
   1.1995572e+00 1.1620615e+00]
  [1.8768493e-02 2.8532380e-02 1.5515259e-02 ... 3.0358357e+00
   2.7524686e+00 1.7740120e+00]
  [7.3735036e-02 1.0783246e-01 5.4595791e-02 ... 4.4480214e+00
   2.5631208e+00 1.3061752e+00]
  ...
  [5.2404919e+01 2.6517214e+01 1.2370697e+01 ... 1.4021972e-01
   1.5950106e-01 9.8013384e-03]
  [3.2057415e+01 4.0431886e+00 1.618

### Dataset class implementation

In [151]:
class HydrogenDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, h5_file, root_dir):
        """
        Args:
            h5_file (string): name of the h5 file with 32 sampled cubes.
            root_dir (string): Directory with the .h5 file.
        """
        file_size = os.path.getsize(root_dir + h5_file) / 1e6 # in MBs
        print("The file size is " + str(int(file_size)) + " MBs")
        
        # self.subcubes = h5py.File('../data/sample_32.h5', 'r')
        self.subcubes = h5py.File(root_dir + h5_file, 'r')['sample32']
        self.h5_file = h5_file
        self.root_dir = root_dir

    def __len__(self):
        # Function called when len(self) is executed
        
        #print(len(self.subcubes))
        return len(self.subcubes)

    def __getitem__(self, idx):
        """
        This can be implemented in such a way that the whole h5 file read 
        using h5py.File() and get_sample() function is called to return
        a random subcube. This won't increase memory usage because the
        subcubes will be read in the same way and only the batch will
        be read into memory.
        
        Here we have implemented it so that it can be used with data
        generated by get_sample() function.
        
        The output of this function is one subcube with the dimensions
        specified by get_sample() implementation.
        """
        
        # default version -> error in training because of dimensions
        #sample = self.subcubes[idx]
        
        # reshaped version to add another dimension
        sample = self.subcubes[idx].reshape((1,128,128,128))

        return sample

In [152]:
sampled_subcubes = HydrogenDataset(h5_file="sample_32.h5",
                                    root_dir = "../data/")
sampled_subcubes

The file size is 268 MBs


<__main__.HydrogenDataset at 0x1b14517f0>

In [153]:
len(sampled_subcubes)

32

In [149]:
sampled_subcubes[0]

array([[[[2.7504480e-03, 1.0773394e-02, 6.7715812e-03, ...,
          1.3569459e+00, 1.0198050e+00, 8.8540632e-01],
         [7.3223263e-03, 7.8354850e-03, 4.2035379e-02, ...,
          1.2797499e+00, 1.5884297e+00, 1.7761337e+00],
         [6.9693945e-02, 5.6814358e-02, 2.5016809e-01, ...,
          3.3856087e+00, 3.5277927e+00, 2.4171164e+00],
         ...,
         [4.0532215e+01, 3.8151985e+01, 2.3695166e+01, ...,
          0.0000000e+00, 0.0000000e+00, 5.9972785e-02],
         [2.3356449e+01, 2.7734932e+01, 5.1065240e+00, ...,
          0.0000000e+00, 0.0000000e+00, 1.6837816e-01],
         [2.1791283e+01, 2.1270782e+01, 7.6947075e-01, ...,
          7.7227759e-04, 0.0000000e+00, 0.0000000e+00]],

        [[1.6299952e-02, 1.1533336e-02, 1.6717875e-02, ...,
          2.3383579e+00, 1.1995572e+00, 1.1620615e+00],
         [1.8768493e-02, 2.8532380e-02, 1.5515259e-02, ...,
          3.0358357e+00, 2.7524686e+00, 1.7740120e+00],
         [7.3735036e-02, 1.0783246e-01, 5.4595791e-02, .

In [150]:
sampled_subcubes[0].shape

(1, 128, 128, 128)

In [144]:
sampled_subcubes[0].reshape((1,128,128,128)).shape

(1, 128, 128, 128)

In [145]:
sampled_subcubes[0].reshape((1,128,128,128))

array([[[2.7504480e-03, 1.0773394e-02, 6.7715812e-03, ...,
         1.3569459e+00, 1.0198050e+00, 8.8540632e-01],
        [7.3223263e-03, 7.8354850e-03, 4.2035379e-02, ...,
         1.2797499e+00, 1.5884297e+00, 1.7761337e+00],
        [6.9693945e-02, 5.6814358e-02, 2.5016809e-01, ...,
         3.3856087e+00, 3.5277927e+00, 2.4171164e+00],
        ...,
        [4.0532215e+01, 3.8151985e+01, 2.3695166e+01, ...,
         0.0000000e+00, 0.0000000e+00, 5.9972785e-02],
        [2.3356449e+01, 2.7734932e+01, 5.1065240e+00, ...,
         0.0000000e+00, 0.0000000e+00, 1.6837816e-01],
        [2.1791283e+01, 2.1270782e+01, 7.6947075e-01, ...,
         7.7227759e-04, 0.0000000e+00, 0.0000000e+00]],

       [[1.6299952e-02, 1.1533336e-02, 1.6717875e-02, ...,
         2.3383579e+00, 1.1995572e+00, 1.1620615e+00],
        [1.8768493e-02, 2.8532380e-02, 1.5515259e-02, ...,
         3.0358357e+00, 2.7524686e+00, 1.7740120e+00],
        [7.3735036e-02, 1.0783246e-01, 5.4595791e-02, ...,
         4.448

In [103]:
sampled_subcubes[0][0]

array([[2.7504480e-03, 1.0773394e-02, 6.7715812e-03, ..., 1.3569459e+00,
        1.0198050e+00, 8.8540632e-01],
       [7.3223263e-03, 7.8354850e-03, 4.2035379e-02, ..., 1.2797499e+00,
        1.5884297e+00, 1.7761337e+00],
       [6.9693945e-02, 5.6814358e-02, 2.5016809e-01, ..., 3.3856087e+00,
        3.5277927e+00, 2.4171164e+00],
       ...,
       [4.0532215e+01, 3.8151985e+01, 2.3695166e+01, ..., 0.0000000e+00,
        0.0000000e+00, 5.9972785e-02],
       [2.3356449e+01, 2.7734932e+01, 5.1065240e+00, ..., 0.0000000e+00,
        0.0000000e+00, 1.6837816e-01],
       [2.1791283e+01, 2.1270782e+01, 7.6947075e-01, ..., 7.7227759e-04,
        0.0000000e+00, 0.0000000e+00]], dtype=float32)

In [104]:
sampled_subcubes[0][0].shape

(128, 128)

In [128]:
sampled_subcubes[0][0][0].shape

(128,)

In [126]:
sampled_subcubes[0][0][0]

array([2.7504480e-03, 1.0773394e-02, 6.7715812e-03, 3.3900790e-02,
       3.6065910e-02, 1.3558009e-01, 1.5031624e-01, 1.9707607e-01,
       6.8112306e-02, 1.8992294e-02, 4.0777373e-01, 2.9132786e+00,
       3.3128488e+00, 8.8518590e-01, 5.2593980e-02, 1.0766571e-02,
       7.6947524e-04, 2.6130702e-03, 6.6783349e-03, 7.7402918e-03,
       3.2814112e-02, 4.3553855e-02, 4.7022242e-02, 6.9133930e-02,
       3.0500990e-01, 5.5921251e-01, 5.1312238e-01, 1.1679703e+00,
       1.4452523e+00, 8.5361522e-01, 1.6657475e-01, 4.2290556e-01,
       2.9148549e-01, 4.0753132e-01, 4.7409824e-01, 2.8812128e-01,
       5.0406050e-02, 1.1771791e-01, 1.8866345e-01, 8.7852821e-02,
       8.9852504e-02, 1.8715803e-01, 1.1504405e-01, 1.3232701e-01,
       1.2600681e-01, 9.9342264e-02, 3.7569772e-02, 1.3884656e-01,
       6.2801696e-02, 9.9893667e-02, 9.6493252e-02, 7.6953590e-02,
       1.5652990e-01, 1.0701891e-01, 1.6782002e-01, 2.3484020e-01,
       1.7080885e-01, 1.3534851e-01, 2.2052497e-01, 1.7668775e

In [127]:
sampled_subcubes[0][0][0][0]

0.002750448

#### 3D Plotting of the HydrogenDataset outputs

In [108]:
sampled_subcubes[0].nonzero()

(array([  0,   0,   0, ..., 127, 127, 127]),
 array([  0,   0,   0, ..., 127, 127, 127]),
 array([  0,   1,   2, ..., 124, 125, 126]))

In [110]:
x,y,z = sampled_subcubes[0].nonzero()

In [133]:
print(x.shape)
print(y.shape)
print(z.shape)

(1363019,)
(1363019,)
(1363019,)


In [137]:
fig = plt.figure(figsize=(12,6))
ax = fig.add_subplot(111, projection='3d')

# x =[1,2,3,4,5,6,7,8,9,10]
# y =[5,6,2,3,13,4,1,2,4,8]
# z =[2,3,3,3,5,7,9,11,9,10]

for x_loc,y_loc,z_loc in zip(x,y,z):
    #print(sampled_subcubes[0][x_loc][y_loc][z_loc])
    ax.scatter3D(x_loc,
                 y_loc,
                 z_loc, 
                 c='blue', marker='o', 
                 alpha = sampled_subcubes[0][x_loc][y_loc][z_loc])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

KeyboardInterrupt: 

ValueError: RGBA values should be within 0-1 range

<matplotlib.figure.Figure at 0x12b6ae588>

### DataLoader Implementation

In [140]:
dataloader = DataLoader(sampled_subcubes, 
                        batch_size=4,
                        shuffle=True, 
                        num_workers=4)

In [141]:
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched.size())

    # observe 4th batch and stop.
    if i_batch == 3:
        break

0 torch.Size([4, 128, 128, 128])
1 torch.Size([4, 128, 128, 128])
2 torch.Size([4, 128, 128, 128])
3 torch.Size([4, 128, 128, 128])
