In [None]:
"""I Target： Here we learn how to load made dataset, like divide into train/test, batch_size, shuffle

II Definition: In PyTorch, Dataset is a class; to load your own data,
   subclass torch.utils.data.Dataset (implementing __getitem__ and __len__). A Dataset is indexed with __getitem__, whereas iterating a DataLoader (via __iter__) yields batches of images and labels.

    PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. 
    Dataset stores the samples and their corresponding labels, 
    and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

III Instances:
    2.0 torch.utils.data.DataLoader https://pytorch.org/docs/stable/data.html
    2.1 __iter__  [magic method]
    2.2 __len__  [magic method]
    2.3.0 enumerate
    2.3.1 tqdm
    2.4 collate_fn

IV Compare 2 then Generalize

V Test in New instance 
"""

In [2]:
import os 
from matplotlib.pyplot import interactive
import torch

from torch.utils.data import Dataset,DataLoader
from torch.utils.data._utils import collate
from torchvision import transforms
from torchvision.datasets.mnist import MNIST



In [8]:
"""Use Dataset generate before"""

# Location of the already-prepared MNIST files (download=False below, so they must exist).
# Override with the MNIST_ROOT environment variable instead of editing this absolute path.
MNIST_ROOT = os.environ.get(
    "MNIST_ROOT", "/home/hpczeji1/hpc-work/Codebase/Datasets/mnist_data"
)

# ToTensor() turns each image into a float tensor; Normalize(mean=0.5, std=0.5)
# then applies (x - 0.5) / 0.5 per channel (single channel here).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

train_dataset = MNIST(
    root=MNIST_ROOT,
    train=True,  # training split
    transform=transform,
    target_transform=None,  # labels are left as plain Python ints (<class 'int'>)
    download=False,
)


In [11]:
def DataLoade_is_Iterable(dataset=None, batch_size=10000):
    """Build a DataLoader from a Dataset and check which of the two is Iterable.

    The map-style Dataset is not an instance of ``collections.abc.Iterable``
    (False for the MNIST dataset above), while the DataLoader is (True). Batches
    are therefore obtained by iterating over the loader. The first batch is then
    inspected: it is a list ``[images, labels]`` of two tensors.

    Args:
        dataset: Dataset to wrap. Defaults to the module-level ``train_dataset``.
        batch_size: Number of samples per batch (10000 in the original demo).
    """
    from collections.abc import Iterable

    if dataset is None:
        dataset = train_dataset
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

    print(f"isinstance(train_dataset,Iterable):{isinstance(dataset, Iterable)}")  # False
    print(f"isinstance(train_loader,Iterable):{isinstance(train_loader, Iterable)}")  # True, loader is Iterable, but Dataset not

    print("type(train_loader): {}".format(type(train_loader)))  # <class 'torch.utils.data.dataloader.DataLoader'>

    # Only the first batch is needed; the None default avoids StopIteration on an empty dataset.
    batch = next(iter(train_loader), None)
    if batch is None:
        print("train_loader yielded no batches")
        return

    # Expected values in the comments below are for MNIST with the default batch_size.
    print("type(batch): {}".format(type(batch)))  # <class 'list'>
    print("len(batch): {}".format(len(batch)))  # 2
    print("type(batch[0]): {}".format(type(batch[0])))  # <class 'torch.Tensor'>
    print("type(batch[1]): {}".format(type(batch[1])))  # <class 'torch.Tensor'>
    print("batch[0].shape: {}".format(batch[0].shape))  # torch.Size([10000, 1, 28, 28])
    print("batch[1].shape: {}".format(batch[1].shape))  # torch.Size([10000])

In [13]:
def len_of_dataloader(dataset=None, batch_size=10000):
    """Print the length of a DataLoader next to the length of its Dataset.

    ``len(loader)`` counts batches, not samples. With the default
    ``drop_last=False`` this is ``ceil(len(dataset) / batch_size)``, e.g.
    60000 MNIST samples / 10000 -> 6. ``len(loader.dataset)`` counts samples.

    Args:
        dataset: Dataset to wrap. Defaults to the module-level ``train_dataset``.
        batch_size: Number of samples per batch (10000 in the original demo).
    """
    if dataset is None:
        dataset = train_dataset
    # shuffle only changes the order of samples, not the number of batches.
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    print(f"len(train_loader): {len(train_loader)}")
    print(f"len(train_dataset): {len(train_loader.dataset)}")

In [15]:
def enumerate_print(dataset=None, batch_size=10000):
    """Iterate a DataLoader with ``enumerate`` to get the batch index and content.

    Each item of the loader unpacks into ``(x, y)``: the image tensor and the
    label tensor of one batch. Only the first batch is printed.

    Args:
        dataset: Dataset to wrap. Defaults to the module-level ``train_dataset``.
        batch_size: Number of samples per batch (10000 in the original demo).
    """
    if dataset is None:
        dataset = train_dataset
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    for batch_idx, (x, y) in enumerate(train_loader):
        print("batch: {}, type(x): {}, type(y): {}".format(batch_idx, type(x), type(y)))
        # batch: 0, type(x): <class 'torch.Tensor'>, type(y): <class 'torch.Tensor'>
        print("batch: {}, x.shape: {}, y.shape: {}".format(batch_idx, x.shape, y.shape))
        # batch: 0, x.shape: torch.Size([10000, 1, 28, 28]), y.shape: torch.Size([10000])  (MNIST, defaults)
        break


In [16]:
# Run one demo per execution of this cell. To see the other checks,
# call DataLoade_is_Iterable() or len_of_dataloader() here instead.
enumerate_print()

batch: 0, type(x): <class 'torch.Tensor'>, type(y): <class 'torch.Tensor'>
batch: 0, x.shape: torch.Size([10000, 1, 28, 28]), y.shape: torch.Size([10000])
