# In [9]:
import os
import zipfile
import joblib as pkl
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

# In [10]:
def pickle(value=None, filename=None):
    """Serialize *value* to *filename* using joblib.

    Note: this helper shadows the name of the stdlib ``pickle`` module within
    this file.

    Args:
        value: Object to persist. Any object other than ``None`` is accepted,
            including falsy ones such as an empty dict.
        filename: Destination path; must be non-empty.

    Raises:
        ValueError: If ``value`` is ``None`` or ``filename`` is missing/empty.
    """
    # Compare ``value`` against None instead of testing its truthiness:
    # containers (and anything defining ``__len__``) are falsy when empty and
    # would otherwise be wrongly rejected as "missing".
    if value is None or not filename:
        raise ValueError("value and filename are required".capitalize())
    pkl.dump(value=value, filename=filename)

# In [13]:
# Working directories, relative to the current working directory:
#   to_extract - where the zipped image archive is unpacked
#   to_save    - where the serialized dataloader artifacts are written
to_extract, to_save = "../data/raw/", "../data/processed/"

# In [34]:
class Loader:
    """Extract a zipped image dataset and build a shuffled ``DataLoader``.

    The archive is extracted into the module-level ``to_extract`` directory;
    ``ImageFolder`` then reads from ``<to_extract>/Dataset``. The resulting
    dataloader and the dataset's ``class_to_idx`` mapping are also saved with
    joblib into the module-level ``to_save`` directory.

    Args:
        image_path: Path to the zip archive containing the images.
        batch_size: Number of images per batch.
        image_size: Side length, in pixels, of the square output images.
        normalized: If True, a ``Normalize(mean=0.5, std=0.5)`` step is added
            to the preprocessing pipeline.
    """

    def __init__(self, image_path = None, batch_size = 64, image_size = 64, normalized = True):
        self.image_path = image_path
        self.batch_size = batch_size
        self.image_size = image_size
        self.use_normalized = normalized

    def unzip_images(self):
        """Extract the archive at ``self.image_path`` into ``to_extract``.

        Raises:
            Exception: If the ``to_extract`` directory does not exist.
        """
        # Validate the destination first so the archive is not opened when
        # there is nowhere to extract it.
        if not os.path.exists(to_extract):
            raise Exception("Extracting images failed".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(to_extract)

    def _normalized(self):
        """Build the preprocessing pipeline applied to every image.

        The pipeline has these steps:

        1. Convert the image to a tensor.
        2. Resize it so the shorter side equals ``image_size``.
        3. Center-crop it to ``image_size x image_size``.
        4. If ``use_normalized`` is set, normalize each of the three input
           channels with mean 0.5 and std 0.5.
        5. Reduce it to a single grayscale channel.

        Returns:
            transforms.Compose: The composed transform. A pipeline is returned
            in both modes; returning ``None`` would leave ``ImageFolder``
            yielding PIL images, which the DataLoader cannot batch.
        """
        steps = [
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            # Resize with an int keeps the aspect ratio, so non-square images
            # would come out with different shapes and could not be stacked
            # into a batch; cropping makes every sample the same size.
            transforms.CenterCrop(self.image_size),
        ]
        if self.use_normalized:
            # Per-channel mean=0.5 / std=0.5 over the three RGB channels.
            steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        steps.append(transforms.Grayscale(num_output_channels=1))

        return transforms.Compose(steps)

    @staticmethod
    def class_to_idx(dataset = None):
        """Return ``dataset.class_to_idx``, or ``None`` when no dataset is given."""
        if dataset is not None:
            return dataset.class_to_idx

    def create_dataloader(self):
        """Build the dataset and dataloader, and persist them to ``to_save``.

        Returns:
            tuple: ``(dataloader, class_to_idx)``, where ``class_to_idx`` is
            the ``ImageFolder`` dataset's mapping attribute.

        Raises:
            Exception: If ``to_extract`` or ``to_save`` does not exist.
        """
        if not os.path.exists(to_extract):
            raise Exception("Extracting images failed from the create dataloader method".capitalize())

        dataset = ImageFolder(root=os.path.join(to_extract, "Dataset"), transform=self._normalized())
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        if not os.path.exists(to_save):
            raise Exception("Creating dataloader failed".capitalize())

        # Persisting is best-effort: a serialization failure is printed but
        # does not prevent the in-memory dataloader from being returned.
        try:
            pickle(value=dataloader, filename=os.path.join(to_save, "dataloader.pkl"))
            # NOTE: "dataset.pkl" stores the class_to_idx mapping, not the dataset.
            pickle(value=Loader.class_to_idx(dataset=dataset), filename=os.path.join(to_save, "dataset.pkl"))
        except Exception as e:
            print(e)

        return dataloader, dataset.class_to_idx


if __name__ == "__main__":
    # The archive path defaults to a machine-specific absolute path; set the
    # ARCHIVE_PATH environment variable to point at a local copy instead.
    loader = Loader(
        image_path=os.environ.get(
            "ARCHIVE_PATH", "/Users/shahmuhammadraditrahman/Desktop/archive.zip"
        ),
        batch_size=64,
        image_size=64,
        normalized=True,
    )

    loader.unzip_images()
    dataloader, labels = loader.create_dataloader()