In [23]:
# Import all libraries
import os
import yaml
import zipfile
import cv2
import joblib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image


In [24]:
# Config: where the extracted archive lives (raw) and where the pickled
# tensors / DataLoader are written (processed). Both paths are relative to
# the current working directory.
RAW_PATH: str = "../../data/raw/"
PROCESSED_PATH: str = "../../data/processed/"

In [25]:
# Utils file
def clean(filename, path=None):
    """Empty the raw-data directory.

    Removes every entry directly inside ``path`` (``RAW_PATH`` by default),
    including sub-directories such as the folder produced by extracting the
    dataset zip archive.

    Args:
        filename: Used only as a guard; any falsy value (``None``, ``""``)
            aborts with ``ValueError``. NOTE(review): the value is otherwise
            unused — confirm what callers intend to pass here.
        path: Directory to empty. Defaults to ``RAW_PATH``.

    Raises:
        ValueError: If ``filename`` is falsy.
    """
    import shutil

    if not filename:
        raise ValueError("File not found".capitalize())

    target = RAW_PATH if path is None else path
    if not os.path.isdir(target):
        return  # nothing to clean

    for entry in os.listdir(target):
        entry_path = os.path.join(target, entry)
        # os.remove cannot delete directories (e.g. the extracted dataset
        # folder); symlinks are unlinked rather than followed.
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)
    
def config(path="../../deafult_params.yml"):
    """Load the project's YAML parameter file.

    Args:
        path: Location of the YAML file; defaults to the original hard-coded
            path. NOTE(review): "deafult" looks like a typo for "default" but
            is kept because the file on disk may carry the same spelling —
            confirm before renaming.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII YAML content.
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
    
def load_pickle(value=None, filename=None):
    """Persist ``value`` to ``filename`` via ``joblib.dump``.

    Despite its name this function *writes* (dumps) rather than loads. It is
    a silent no-op when either argument is ``None``, and always returns
    ``None``.
    """
    if value is None or filename is None:
        return None
    joblib.dump(value=value, filename=filename)

In [26]:
class Loader:
    """Build an image/mask segmentation ``DataLoader`` from a zipped dataset.

    Layout expected after extraction (as read by ``create_dataloader``):
    ``RAW_PATH/<dataset_dir>/<category>/<files>``, where each image file has
    a companion mask file whose name contains ``self.is_mask``.
    NOTE(review): masks are assumed to be named ``<stem>_mask<ext>`` (e.g.
    ``a.png`` -> ``a_mask.png``) — confirm against the dataset.

    Args:
        image_path: Path of the ``.zip`` archive extracted into ``RAW_PATH``.
        batch_size: Batch size of the produced ``DataLoader``.
    """

    def __init__(self, image_path = None, batch_size = 32):
        self.image_path = image_path
        self.batch_size = batch_size
        self.directory = None  # dataset folder inside RAW_PATH (set later)
        self.categories = None  # sub-folder names of self.directory
        self.base_images = list()  # transformed image tensors
        self.mask_images = list()  # mask tensors, index-aligned with base_images
        self.is_mask = "mask"  # substring that marks a file as a mask

    def base_transformation(self):
        """Return the resize -> tensor -> normalize pipeline for colour images."""
        data_cfg = config()["data"]  # read the YAML once, not once per value
        # NOTE(review): the same config value is used as both mean and std for
        # every channel — confirm this is intended.
        stat = data_cfg["transforms"]
        return transforms.Compose(
            [
                # torchvision's Resize takes (height, width), not (width, height).
                transforms.Resize((data_cfg["image_height"], data_cfg["image_width"])),
                transforms.ToTensor(),
                transforms.Normalize(mean=[stat, stat, stat], std=[stat, stat, stat]),
            ]
        )

    def mask_transformation(self):
        """Return the resize -> tensor -> grayscale -> normalize pipeline for masks."""
        data_cfg = config()["data"]
        stat = data_cfg["transforms"]
        return transforms.Compose(
            [
                # torchvision's Resize takes (height, width), not (width, height).
                transforms.Resize((data_cfg["image_height"], data_cfg["image_width"])),
                transforms.ToTensor(),
                transforms.Grayscale(num_output_channels=1),
                transforms.Normalize(mean=[stat], std=[stat]),
            ]
        )

    def unzip_folder(self):
        """Extract the archive at ``self.image_path`` into ``RAW_PATH``.

        ``RAW_PATH`` is created first when missing. Previously a first run only
        created the directory and skipped extraction entirely.
        """
        os.makedirs(RAW_PATH, exist_ok=True)
        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(RAW_PATH)

    @staticmethod
    def _list_visible(path, directories):
        """Return sorted entry names of ``path`` that are directories (or
        non-directories when ``directories`` is False), skipping hidden entries
        and the ``__MACOSX`` folder that macOS-created zip archives contain."""
        return [
            name
            for name in sorted(os.listdir(path))
            if not name.startswith(".")
            and name != "__MACOSX"
            and os.path.isdir(os.path.join(path, name)) == directories
        ]

    def _mask_filename(self, image):
        """Return the mask file name paired with ``image`` (``a.png`` -> ``a_mask.png``)."""
        # splitext handles names with several dots or no extension at all.
        stem, ext = os.path.splitext(image)
        return "{}_{}{}".format(stem, self.is_mask, ext)

    @staticmethod
    def _read_image(path, grayscale=False):
        """Read ``path`` with OpenCV and wrap it as a PIL image.

        ``cv2.imread`` returns ``None`` instead of raising on unreadable files,
        so that case is turned into an explicit error. Colour images are
        converted from OpenCV's BGR order to the RGB order PIL assumes.
        """
        if grayscale:
            array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        else:
            array = cv2.imread(path)
        if array is None:
            raise FileNotFoundError("Could not read image: {}".format(path))
        if not grayscale:
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        return Image.fromarray(array)

    def create_dataloader(self):
        """Load every image/mask pair, build a shuffled DataLoader and persist
        the images, masks and loader into ``PROCESSED_PATH`` via ``load_pickle``.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``(image, mask)`` pairs.

        Raises:
            Exception: If ``PROCESSED_PATH`` does not exist (checked first so no
                work is wasted).
            FileNotFoundError: If ``RAW_PATH`` holds no dataset directory, or an
                image or its mask cannot be read.
        """
        if not os.path.exists(PROCESSED_PATH):
            raise Exception("PROCESSED_PATH does not exist".capitalize())

        dataset_dirs = self._list_visible(RAW_PATH, directories=True)
        if not dataset_dirs:
            raise FileNotFoundError("No dataset directory found in {}".format(RAW_PATH))
        self.directory = os.path.join(RAW_PATH, dataset_dirs[0])
        self.categories = self._list_visible(self.directory, directories=True)

        # Build each pipeline once instead of once per image.
        base_transform = self.base_transformation()
        mask_transform = self.mask_transformation()

        # Start fresh so repeated calls do not duplicate samples.
        self.base_images = list()
        self.mask_images = list()

        for category in self.categories:
            folder_path = os.path.join(self.directory, category)
            for image in self._list_visible(folder_path, directories=False):
                if self.is_mask in image:
                    continue  # masks are loaded alongside their image

                image_file = os.path.join(folder_path, image)
                mask_file = os.path.join(folder_path, self._mask_filename(image))

                self.base_images.append(base_transform(self._read_image(image_file)))
                self.mask_images.append(
                    mask_transform(self._read_image(mask_file, grayscale=True))
                )

        dataloader = DataLoader(
            dataset=list(zip(self.base_images, self.mask_images)),
            batch_size=self.batch_size,
            shuffle=True,
        )
        load_pickle(value=self.base_images, filename=os.path.join(PROCESSED_PATH, "base_images.pkl"))
        load_pickle(value=self.mask_images, filename=os.path.join(PROCESSED_PATH, "mask_images.pkl"))
        load_pickle(value=dataloader, filename=os.path.join(PROCESSED_PATH, "dataloader.pkl"))

        return dataloader
                


if __name__ == "__main__":
    # Script entry point: extract the archive into RAW_PATH, then build and
    # persist the DataLoader.
    # NOTE(review): the archive path below is specific to one machine.
    zip_archive = "/Users/shahmuhammadraditrahman/Desktop/semantic.zip"
    loader = Loader(image_path=zip_archive, batch_size=32)
    loader.unzip_folder()
    dataloader = loader.create_dataloader()