In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF

import matplotlib.pyplot as plt
import numpy as np

In [2]:
class HyperParameters(dict):
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
        
    def __repr__(self):
        return '\n'.join(f"Parameter: {k:<16} Value: {v}" for k,v in self.items())
            
        

In [3]:
class CIFAR10DataSet():
    def __init__(self, hparams):
        super().__init__()
        self.hparams = HyperParameters(hparams)
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)
        
    def train_dataloader(self):
        transform = T.Compose(
            [
                T.Resize((self.hparams.img_size, self.hparams.img_size)),
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR10(
            root=self.hparams.data_dir, 
            train=True, 
            download=True, 
            transform=transform,
        )
        dataloader = DataLoader(
            dataset, 
            batch_size=self.hparams.batch_size, 
            shuffle=True, 
            num_workers=self.hparams.num_workers,
        )
        return dataloader
    
    def get_classes(self):
        return ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    def _imshow(self, imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img / 2 + 0.5
            img = img.detach()
            img = TF.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
    def show_grid(self, dataloader):
        it = iter(dataloader)
        images, labels = next(it)
        self._imshow(make_grid(images))
        classes = self.get_classes()
        print(' '.join(f'{classes[labels[j]]:s}' for j in range(len(images))))

    

In [4]:
cds_config = {
    'data_dir':"./data",
    'batch_size': 32,
    'num_workers': 4,
    'img_size':224,
}


In [5]:
cds = CIFAR10DataSet(cds_config)
dl = cds.train_dataloader()

Files already downloaded and verified


In [None]:
cds.show_grid(dl)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/multiprocessing/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/multiprocessing/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/multiprocessing/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/runpy.py", line 269, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/Users/ajithj/mambaforge/envs/vit-step-by-step/lib/python3.10/runpy.py", line 96, in _run_module_c