<a href="https://colab.research.google.com/github/SCCSMARTCODE/Deep-Learning-00/blob/main/Cifar10_DataLoader/cifar_data_loader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import requests
import os
import torch
import tarfile
import numpy as np

In [None]:
class LoadCifar10:
    """
    Load the CIFAR-10 (python version) dataset from disk, optionally
    downloading and extracting it first.

    After construction, ``self.dataset`` is a list of ``(image, label)``
    tuples. Each image is one row of a batch file's ``b'data'`` array,
    reshaped to ``(3, 32, 32)`` and passed through every transform in order.

    Args:
        root: Directory that contains (or will contain) the extracted
            ``cifar-10-batches-py`` folder. A falsy value (``None``, ``""``)
            produces an empty loader without touching the filesystem.
        download: When true, fetch and extract the archive into ``root``
            unless the extracted folder already exists.
        transforms: A single callable or an iterable of callables applied to
            each reshaped image.
        train: When true, load the files whose name contains ``data_batch``;
            otherwise load the files whose name contains ``test_batch``.
    """
    # Download URL of the archive and the temporary file name used on disk.
    file_path="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    file_name="zip_cifar_10.tar.gz"


    def __init__(self, root, download=False, transforms=None, train=True):
        self.root = root
        self.download = download
        self.transforms = []
        if transforms:
            try:
                self.transforms.extend(transforms)
            except TypeError:
                # A single, non-iterable callable was passed.
                self.transforms.append(transforms)
        self.train = train
        # ``root or ""`` keeps a missing root from crashing os.path.join
        # before the guard below gets a chance to run.
        self.dataset_folder = os.path.join(root or "", "cifar-10-batches-py")
        self.dataset = []

        if not root:
            return

        # download if self.download is True
        if self.download:
            self.download_dataset()

        # loading dataset
        if not os.path.isdir(self.dataset_folder):
            print("File doesn't exist\nset download to True to download it")
            return

        # locate the necessary files based on requirement; sorted because
        # os.listdir returns entries in arbitrary order and the sample order
        # should be reproducible.
        wanted = "data_batch" if self.train else "test_batch"
        file_names = sorted(
            name for name in os.listdir(self.dataset_folder) if wanted in name
        )

        # convert the files to usable data
        for file_name in file_names:
            batch = self.unpickle(os.path.join(self.dataset_folder, file_name))
            images = [self.format_image(row) for row in batch[b"data"]]
            self.dataset.extend(zip(images, batch[b"labels"]))


    def download_dataset(self):
        """
        Download the CIFAR-10 archive into ``self.root`` and extract it.

        Does nothing (apart from printing a notice) when the extracted folder
        already exists. The downloaded archive is always removed afterwards,
        including when the download or the extraction fails.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """

        if os.path.exists(self.dataset_folder):
            print("File exists...")
            return

        os.makedirs(self.root, exist_ok=True)
        downloaded_file_path = os.path.join(self.root, self.file_name)
        try:
            # Stream to disk so the whole archive is never held in memory.
            with requests.get(self.file_path, stream=True, timeout=60) as response:
                # Any 4xx/5xx (not only 404) means there is no archive to save.
                response.raise_for_status()
                with open(downloaded_file_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            print("Cifar10 Downloaded Successfully...")

            # Extract the archive.
            # NOTE(review): extractall trusts the member paths in the archive;
            # consider tarfile extraction filters on Python versions that
            # support them.
            with tarfile.open(downloaded_file_path, "r:gz") as archive:
                archive.extractall(self.root)
        finally:
            # delete the archive, even a partial one left by a failure
            if os.path.exists(downloaded_file_path):
                os.remove(downloaded_file_path)

        print("Cifar10 Extracted Successfully...")


    def unpickle(self, file):
        """
        Read one pickled batch file and return the object stored in it.

        The file is loaded with ``encoding='bytes'``, so string keys pickled
        by Python 2 come back as ``bytes`` (e.g. ``b'data'``, ``b'labels'``).

        NOTE(review): pickle can execute arbitrary code while loading; only
        use this on files from a trusted source.
        """
        import pickle
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def format_image(self, image, size=(3,32,32)):
        """
        Reshape ``image`` to ``size`` and apply every configured transform.

        Args:
            image: Object exposing ``reshape`` (one row of a batch's data).
            size: Target shape passed to ``image.reshape``.

        Returns:
            The reshaped image after all transforms ran, in order.
        """
        image = image.reshape(size)

        for transform in self.transforms:
            image = transform(image)
        return image

    def __len__(self):
        """Return the number of loaded ``(image, label)`` samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(image, label)`` sample(s) at ``index``."""
        return self.dataset[index]

    def __str__(self):
        # NOTE(review): this returns the sample list, not a string, so
        # ``str(obj)`` / ``print(obj)`` raise TypeError. Kept unchanged because
        # existing code calls ``obj.__str__()`` directly to reach the samples;
        # prefer ``obj.dataset`` or indexing ``obj[i]`` instead.
        return self.dataset


In [None]:
from torchvision.transforms import ToTensor

# Build the training split from the Google Drive folder, downloading and
# extracting the archive first when it is not already there.
output = LoadCifar10(
    root="/content/drive/MyDrive/Deep Learning/CIFAR_10",
    download=True,
    train=True,
)

In [None]:
# Show the first (image, label) sample. Read it from the `dataset` list
# directly instead of calling the `__str__` dunder and indexing its result.
print(output.dataset[0])

(tensor([[[0.6980, 0.7059, 0.6941,  ..., 0.4392, 0.4392, 0.4039],
         [0.6902, 0.6980, 0.6863,  ..., 0.4196, 0.4000, 0.3765],
         [0.7412, 0.7490, 0.7373,  ..., 0.4196, 0.3961, 0.3608]],

        [[0.6980, 0.7020, 0.6941,  ..., 0.4431, 0.4392, 0.3922],
         [0.6902, 0.6941, 0.6863,  ..., 0.4275, 0.4039, 0.3647],
         [0.7412, 0.7451, 0.7373,  ..., 0.4235, 0.4000, 0.3529]],

        [[0.6980, 0.7059, 0.6980,  ..., 0.4471, 0.4431, 0.4039],
         [0.6902, 0.6980, 0.6902,  ..., 0.4314, 0.4039, 0.3725],
         [0.7412, 0.7490, 0.7412,  ..., 0.4314, 0.4039, 0.3686]],

        ...,

        [[0.6667, 0.6784, 0.6706,  ..., 0.3922, 0.4000, 0.3608],
         [0.6588, 0.6706, 0.6627,  ..., 0.3804, 0.3725, 0.3294],
         [0.7059, 0.7137, 0.7059,  ..., 0.3686, 0.3647, 0.3137]],

        [[0.6588, 0.6706, 0.6627,  ..., 0.3843, 0.4000, 0.3647],
         [0.6510, 0.6627, 0.6549,  ..., 0.3686, 0.3647, 0.3373],
         [0.6941, 0.7059, 0.6980,  ..., 0.3647, 0.3569, 0.3137]],

