In [176]:
import os
import numpy as np
import pandas as pd
import cv2
import xarray as xr
import glob
from typing import Literal

In [None]:
class DataLoader:
    def __init__(self, data_type: Literal["train", "test", "val"]) -> None:
        self.data_type = data_type
        match data_type:
            case "train":
                self.data_path = "../../data/playing_cards/train/"
            case "test":
                self.data_path = "../../data/playing_cards/test/"
            case "val":
                self.data_path = "../../data/playing_cards/val/"
    
    def load_data(self) -> xr.Dataset:
        images = []
        labels = os.listdir(self.data_path)

        # Separate columns for number and suit
        suits = []
        numbers = []

        for label in labels[:2]:
            if "joker" in label:
                continue

            for img_path in glob.glob(
                os.path.join(self.data_path, label, "*.jpg")
            ):
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                images.append(img)

                number, _, suit = label.split()
                
                suits.append(suit)
                numbers.append(number)

        # Create an xr dataset with 3 cols: image, suit, number
        self.ds = xr.Dataset(
            {
                "image": (["image_num", "height", "width"], images), 
                "suit": (["image_num"], suits), 
                "number": (["image_num"], numbers)
            }, 
            coords={
                "image_num": range(len(images))
            }
        )

        return self.ds

In [179]:
# Build a loader for the training split and read it into an xarray Dataset.
train_data_loader = DataLoader(data_type="train")

train = train_data_loader.load_data()

In [180]:
# Show the loaded training dataset; as the cell's last expression it is rendered as the cell output.
train