Skip to content
Switch branches/tags
Go to file
Cannot retrieve contributors at this time
from __future__ import print_function

import os
import os.path
import sys

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity
class HTRU1(data.Dataset):
    """HTRU1 Dataset.

    The data lives in pickled batch files under ``<root>/htru1-batches-py``.
    Each batch is a dict whose ``'data'`` entry holds rows of flattened
    3x32x32 (CHW) images and whose ``'labels'`` (or ``'fine_labels'``) entry
    holds the class index of each row. Images are exposed in HWC layout.

    Args:
        root (string): Root directory of dataset where directory
            ``htru1-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If a batch file or the metadata file is missing or fails
            its MD5 check.
    """

    base_folder = 'htru1-batches-py'
    # NOTE(review): the download URL is empty here, so download=True raises
    # until it is filled in.
    url = ""
    filename = "htru1-batches-py.tar.gz"
    tgz_md5 = 'e7b063301ada3eb50f212afeea185a36'
    # [file name, expected MD5] pairs for every batch of each split.
    train_list = [
        ['data_batch_1', '3a085bdcc186a8f9d8f120adcde8f3d2'],
        ['data_batch_2', '12e4ff7648ffc2047ff4774a6074bc0d'],
        ['data_batch_3', '12c0dd52b4febe4132917cf733ceae2c'],
        ['data_batch_4', 'b377c8a723603c4addf32831607f13e7'],
        ['data_batch_5', 'f6bc78dec3d75e3db005a7a9b7d910c0'],
    ]

    test_list = [
        ['test_batch', 'dc2d5f6ebf826eff1cbb0942705796b9'],
    ]

    # Metadata pickle; ``key`` names the entry holding the class names.
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5429d773dafec7781e0eeacb29768819',
    }

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        # NOTE(security): pickle.load can execute arbitrary code. The MD5
        # checks above pin each file to a known digest, but MD5 is not
        # collision-resistant; only use data obtained from a trusted source.
        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    # latin1 lets Python 3 read NumPy data pickled by Python 2
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        # Stack the batches, restore the CHW image shape, then move channels
        # last so __getitem__ can pass rows straight to Image.fromarray.
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        """Read class names from the metadata pickle.

        Sets ``self.classes`` (the list stored under ``meta['key']``) and
        ``self.class_to_idx`` (name -> position in that list).

        Raises:
            RuntimeError: If the metadata file is missing or fails its MD5 check.
        """
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                meta_data = pickle.load(infile)
            else:
                meta_data = pickle.load(infile, encoding='latin1')
        self.classes = meta_data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data)

    def _check_integrity(self):
        """Return True only if every train and test batch file exists and matches its MD5."""
        root = self.root
        for filename, md5 in self.train_list + self.test_list:
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download and extract the dataset archive into ``root``.

        Does nothing (beyond a message) when the batch files already verify.

        Raises:
            RuntimeError: If no download URL is configured.
        """
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        if not self.url:
            raise RuntimeError('{}.url is not set; cannot download the dataset'
                               .format(self.__class__.__name__))

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        # extract file
        archive_path = os.path.join(self.root, self.filename)
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, 'data_filter'):
                # Where extraction filters exist, the 'data' filter rejects
                # members that would escape root (absolute paths, '..',
                # outside links) and special files such as devices.
                tar.extractall(path=self.root, filter='data')
            else:
                tar.extractall(path=self.root)

    def __repr__(self):
        """Return a multi-line summary: size, split, root and transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(len(self))
        tmp = 'train' if self.train is True else 'test'
        fmt_str += ' Split: {}\n'.format(tmp)
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str