In [1]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
def load_dataset(path,filename):
    f = h5py.File(path+filename)
    input_dataset = np.asarray([f[element[0]][:] for element in f['input_dataset']])
    output_dataset = np.asarray([f[element[0]][:] for element in f['output_dataset']])
    f.close()
    return input_dataset,output_dataset

In [3]:
class DeconvolutionDataset(Dataset):
    """dataset."""

    def __init__(self, input_dataset,output_dataset, transform=None):
        """

        """
        self.input_dataset = input_dataset
        self.output_dataset = output_dataset
        self.transform = transform

    def __len__(self):
        return len(self.output_dataset)

    def __getitem__(self, idx):

        sample = {'input': self.input_dataset[idx,:,:,:], 'output': np.reshape(self.output_dataset[idx,:,:],(1,128,128))}
        
        if self.transform:
            sample = self.transform(sample)

        return sample

In [4]:
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        input_image, output_image = sample['input'], sample['output']
        return {'input': torch.from_numpy(input_image),
                'output': torch.from_numpy(output_image)}

In [5]:
path = './data/'
filename = 'dataset_batch_2.mat'
input_dataset,output_dataset = load_dataset(path,filename)
print(input_dataset.shape,output_dataset.shape)

(528, 200, 128, 128) (528, 128, 128)


In [23]:
training_dataset = DeconvolutionDataset(input_dataset,output_dataset,transform=ToTensor())



In [24]:
print(training_dataset[1])

{'input': tensor([[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         ...,
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000]],

        [[   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         ...,
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,
             0.0000,    0.0000],
         [   0.0000,    0.0000

In [25]:
dataloader = DataLoader(training_dataset, batch_size=4,shuffle=True, num_workers=4)

In [26]:
for batch_idx, sample in enumerate(training_dataset):
        print(batch_idx,sample['input'].size(),sample['output'].size())

0 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
1 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
2 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
3 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
4 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
5 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
6 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
7 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
8 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
9 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
10 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
11 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
12 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
13 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
14 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
15 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
16 torch.Size([200, 128, 128]) torch.Size([1, 128, 128])
17 torch.Size([200, 128, 128]) torch.Size