In [1]:
from datasets import load_dataset

In [16]:
import pytorch_lightning as pl
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import AutoAugmentPolicy, AutoAugment

In [125]:


class TinyImageNet(pl.LightningDataModule):
    def __init__(self, batch_size=512, num_workers=28):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir = './data'
        self.features = None
        self.transform = transforms.Compose([
            # transforms.Resize((64, 64)),
            AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        
    def prepare_data(self):
        load_dataset('zh-plus/tiny-imagenet', cache_dir=self.data_dir)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            dataset = load_dataset('zh-plus/tiny-imagenet', cache_dir=self.data_dir)
            
            train_val_ds = dataset['train'].train_test_split(test_size=0.1)
            self.train_dataset = train_val_ds['train'].with_transform(self.apply_transform)
            self.val_dataset = train_val_ds['test'].with_transform(self.apply_transform)
            self.test_dataset = dataset['valid'].with_transform(self.apply_transform)
            self.features = self.train_dataset.features

    def apply_transform(self, example):
        example['image'] = [self.transform(img.convert('RGB')) for img in example['image']]
        return example

    @staticmethod
    def collate_fn(batch):
        images = torch.stack([item['image'] for item in batch])
        labels = torch.tensor([item['label'] for item in batch])
        return images, labels

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn)




In [126]:
tiny_imagenet = TinyImageNet()
tiny_imagenet.prepare_data()
tiny_imagenet.setup('fit')

In [139]:
import json
with open('label2classname.json','r') as file:
    label2class = json.loads(file.readline())
    label2class = {label:name for k,(label,name) in label2class.items()}

In [156]:
[label2class.get(label) for label in tiny_imagenet.features['label'].names ]

['goldfish',
 'European_fire_salamander',
 'bullfrog',
 'tailed_frog',
 'American_alligator',
 'boa_constrictor',
 'trilobite',
 'scorpion',
 'black_widow',
 'tarantula',
 'centipede',
 'koala',
 'jellyfish',
 'brain_coral',
 'snail',
 'sea_slug',
 'American_lobster',
 'spiny_lobster',
 'black_stork',
 'king_penguin',
 'albatross',
 'dugong',
 'Yorkshire_terrier',
 'golden_retriever',
 'Labrador_retriever',
 'German_shepherd',
 'standard_poodle',
 'tabby',
 'Persian_cat',
 'Egyptian_cat',
 'cougar',
 'lion',
 'brown_bear',
 'ladybug',
 'grasshopper',
 'walking_stick',
 'cockroach',
 'mantis',
 'dragonfly',
 'monarch',
 'sulphur_butterfly',
 'sea_cucumber',
 'guinea_pig',
 'hog',
 'ox',
 'bison',
 'bighorn',
 'gazelle',
 'Arabian_camel',
 'orangutan',
 'chimpanzee',
 'baboon',
 'African_elephant',
 'lesser_panda',
 None,
 'academic_gown',
 'altar',
 'backpack',
 'bannister',
 'barbershop',
 'barn',
 'barrel',
 'basketball',
 'bathtub',
 'beach_wagon',
 'beacon',
 'beaker',
 'beer_bottle

In [119]:
for data,label in tiny_imagenet.train_dataloader():
    print(data.shape,label)
    break

torch.Size([512, 3, 64, 64]) tensor([115, 131, 165,  77,  63,  16, 100,  87, 111, 180,  46, 177, 182, 133,
         99, 125,  61, 189, 143, 158, 102, 193, 132, 104,  10,  72,  60,  19,
        126, 107, 155,  45, 160, 135, 155,  59, 117, 148, 158, 160,  88,  13,
        177,  86,  65,  62, 162, 108,  94, 133,  23,  72,  14, 163, 190, 153,
         66, 161,  26,  90,   4,   3, 196,  18,  73,  74, 199,  48,  48,  46,
        145,  30, 194, 125,  30,  56,  42,  53, 146,   6,  88, 180, 142,  55,
        114,  85, 111, 178,  30, 182,  54, 123,  33,  87,  23, 104,   2, 144,
         44, 180,   5, 149, 160,   0, 113, 161,   9,  12, 142, 179,  72, 101,
         33,  71,  17, 190, 113, 146,  43,  47, 131, 116,  91,  52,  14,  46,
         41,  75, 121,  80, 179, 120,  68, 157,  29, 188, 119,  84,  92, 188,
          0,   8,  89,  53,  27,   3, 158,  68,  35, 170,  74,  50, 159, 182,
         56,  71,  50,  39,  64, 113,  93,  42,   9,   7,  75, 115,  31, 141,
         80,  36, 179, 167,  53, 18

In [1]:
import torchvision

# Corrected dataset imports with required parameters
# coco_detection = torchvision.datasets.CocoDetection(root='./data', annFile='annotations.json')
# celeba = torchvision.datasets.CelebA(root='./data', split='all', download=True)
# voc_segmentation = torchvision.datasets.VOCSegmentation(root='./data', year='2012', image_set='train', download=True)
# voc_detection = torchvision.datasets.VOCDetection(root='./data', year='2012', image_set='train', download=True)
# coco_captions = torchvision.datasets.CocoCaptions(root='./data', annFile='annotations.json')
ucf101 = torchvision.datasets.UCF101(root='./data', annotation_path='ucfTrainTestlist', frames_per_clip=1, step_between_clips=1, fold=1, train=True, transform=None, target_transform=None, download=True)
moving_mnist = torchvision.datasets.MovingMNIST('./data',download=True)
img = torchvision.datasets.FlyingChairs('./data',)
img = torchvision.datasets.Kitti2012Stereo('./data',)
img = torchvision.datasets.LFWPairs('./data',download=True)

Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
Downloading http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz to ./data/lfw-py/lfw-funneled.tgz


100%|█████████████████████████| 243346528/243346528 [06:39<00:00, 608430.34it/s]


Extracting ./data/lfw-py/lfw-funneled.tgz to ./data/lfw-py
Downloading http://vis-www.cs.umass.edu/lfw/pairs.txt to ./data/lfw-py/pairs.txt


100%|███████████████████████████████| 155335/155335 [00:00<00:00, 231347.09it/s]
