In [26]:
import os
import torch
import torchvision.datasets as datasets
import torch.utils.data as data
from PIL import Image
from glob import glob
import numpy as np
class ImageNet1000_limit_images(data.Dataset):
    def __init__(self,root,split='train', 
                    transform=None, nb_images=83):
        super(ImageNet1000_limit_images, self).__init__()

        self.root = os.path.join(root, '%s' %(split))
        self.transform = transform
        self.split = split
#         if split=='train':
        self.nb_images = nb_images
        subdirs = glob(self.root+"/*/")

        # Gather the files (sorted)
        imgs = []
        self.targets = []
        for i, subdir in enumerate(subdirs):
            subdir_path = os.path.join(subdir)
            files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG')))
            
            len_files = len(files)
            limit_nb_images = min(len_files,self.nb_images)
            
            files = files[:limit_nb_images]
            
            for f in files:
                imgs.append((f, i))
            self.targets.append(i)
#         print(subdir_path)
        self.imgs = imgs
        self.targets = np.array(self.targets)

    def get_image(self, index):
        path, target = self.imgs[index]
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        img = self.resize(img) 
        return img

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        im_size = img.size

        if self.transform is not None:
            img = self.transform(img)

        return img,target

In [32]:
dataset = ImageNet1000_limit_images(root='/home/ajha/datasets/Imagenet_downloads/',split='train', transform=None)

In [33]:
it_data = iter(dataset)

In [34]:
next(it_data)

(<PIL.Image.Image image mode=RGB size=500x377 at 0x7F7D9B4477B8>, 0)

In [37]:
dataset.imgs[83]

('/home/ajha/datasets/Imagenet_downloads/train/n04254777/n04254777_10024.JPEG',
 1)

In [8]:
dataset.targets

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])

In [1]:
import glob

In [5]:
aa = glob.glob('/home/ajha/datasets/Imagenet_downloads/train/*/')

In [6]:
len(aa)

1000

In [4]:
aa

['/home/ajha/datasets/Imagenet_downloads/train/']

In [7]:
aa[:10]

['/home/ajha/datasets/Imagenet_downloads/train/n03873416/',
 '/home/ajha/datasets/Imagenet_downloads/train/n04254777/',
 '/home/ajha/datasets/Imagenet_downloads/train/n03476991/',
 '/home/ajha/datasets/Imagenet_downloads/train/n07714571/',
 '/home/ajha/datasets/Imagenet_downloads/train/n01695060/',
 '/home/ajha/datasets/Imagenet_downloads/train/n02497673/',
 '/home/ajha/datasets/Imagenet_downloads/train/n02346627/',
 '/home/ajha/datasets/Imagenet_downloads/train/n03840681/',
 '/home/ajha/datasets/Imagenet_downloads/train/n07695742/',
 '/home/ajha/datasets/Imagenet_downloads/train/n01440764/']