In [6]:
import os  # imported here so this cell runs on a fresh kernel (it uses `os` below)

# Root folder that holds the image splits. The default is the original local
# path; set the MOTORBIKE_DATA_PATH environment variable to point somewhere
# else without editing the notebook.
data_path = os.environ.get('MOTORBIKE_DATA_PATH', r'D:\Data Deep Learning\datamotor\motor\motor')

print(os.path.abspath(data_path))

D:\Data Deep Learning\datamotor\motor\motor


In [12]:
import json
import os
from PIL import Image
from torchvision import transforms as T
from torch.utils.data import Dataset
import numpy as np

import mortobike_project as mp


class MotorBikeDataset(Dataset):
    """Map-style dataset yielding `(image, label)` pairs of motorbike pictures.

    Samples are stored in `self.labels`, a dict mapping image path -> integer
    label, filled by `load_dataset()`.
    """

    def __init__(self, config_path: str, session: str = 'train', data_mode: str = 'csv', **kwargs):
        """
            In this class, there are two data modes to choose from:
            - `csv`: You need to provide a folder containing the images, and a csv file containing the labels
              (not implemented yet: the dataset stays empty in this mode)

            - `ssl`: In this mode, you just need to pass a list of folder paths whose sub-folders are
              already divided into classes

        Args:
            `config_path` (str): The path to the config folder; it must contain `class.json`
            `session` (str, optional): The session of the dataset, must be in [`train`, `val`, `test`]
            `data_mode` (str, optional): The data mode. Defaults to `csv`.
            `kwargs`: Other arguments:

            - For `csv` mode:
                `folder_path` (str): The folder containing the images
                `csv_path` (str): The csv file containing the labels
            - For `ssl` mode:
                `folder_paths` (list): The list of folder paths; each one contains one sub-folder per class.
                A single path string is also accepted.

        Raises:
            AssertionError: If `session` or `data_mode` is not one of the allowed values.
            ValueError: If `config_path` does not exist, or `folder_paths` is missing/empty in `ssl` mode.
        """

        assert session in ['train', 'val', 'test'], 'Invalid session, must be in [train, val, test]'

        assert data_mode in ['csv', 'ssl'], 'Invalid data mode, must be in [csv, ssl]'

        self.data_mode = data_mode
        self.session = session
        self.kwargs = kwargs
        self.config_path = config_path

        # image path -> integer label
        self.labels = {}
        # Cached, index-addressable view of `self.labels` keys (see __getitem__)
        self._image_paths = []

        if not os.path.exists(config_path):
            raise ValueError(f'Config path {config_path} does not exist')

        # Explicit encoding: JSON is UTF-8, while the platform default (e.g. cp1252
        # on Windows) would mis-decode non-ASCII class names.
        with open(os.path.join(self.config_path, 'class.json'), 'r', encoding='utf-8') as f:
            self.config_class: dict = json.load(f)

        # Define the image transform
        self.transform = mp.Transform(session)

        # Load the dataset in the folder
        self.load_dataset()

    def load_dataset(self):
        """Fill `self.labels` according to `self.data_mode`.

        In `ssl` mode the class sub-folder names are read from the FIRST entry of
        `folder_paths` and then looked up in every entry. Class folders and image
        files are visited in sorted order so that index -> sample is reproducible
        across platforms (`os.listdir` order is otherwise arbitrary).

        Raises:
            ValueError: If `folder_paths` is missing or empty in `ssl` mode.
        """
        if self.data_mode == 'csv':
            # TODO: Load the dataset from and match the label from the csv file
            pass
        else:
            folder_paths = self.kwargs.get('folder_paths', None)

            # A lone path would otherwise be indexed character by character.
            if isinstance(folder_paths, (str, os.PathLike)):
                folder_paths = [folder_paths]

            if not folder_paths:
                raise ValueError('`folder_paths` must be a non-empty list of folders in `ssl` mode')

            self.folder_paths = list(folder_paths)

            # 1 folder is "xe_so", 2 folder is "xe_ga", 3, 4, 5 folder is "others"
            # Keep directories only: stray files (e.g. Thumbs.db) are not classes.
            first_folder = self.folder_paths[0]
            self.classes = sorted(
                name for name in os.listdir(first_folder)
                if os.path.isdir(os.path.join(first_folder, name))
            )

            for folder in self.folder_paths:
                for folder_class in self.classes:
                    class_dir = os.path.join(folder, folder_class)
                    # Folders '1' and '2' keep their own number; everything else is 3.
                    # NOTE(review): labels are therefore 1..3, not 0..2 - confirm the
                    # consumer (loss / metrics) expects 1-based class ids.
                    label = int(folder_class) if folder_class in ('1', '2') else 3
                    for img in sorted(os.listdir(class_dir)):
                        self.labels[os.path.join(class_dir, img)] = label

        # Refresh the index cache once, instead of rebuilding it per __getitem__ call.
        self._image_paths = list(self.labels)

    def __len__(self):
        """Return the number of samples currently held in `self.labels`."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return `(image, label)` for the sample at position `index`.

        The image is opened, converted to RGB and, when `self.transform` is not
        None, passed to it as a numpy array; otherwise the PIL image is returned.
        """
        # Rebuild the cache only if `self.labels` changed size behind our back;
        # the original code built this list on every single access (O(n) per item).
        paths = getattr(self, '_image_paths', None)
        if paths is None or len(paths) != len(self.labels):
            paths = self._image_paths = list(self.labels)

        img_path = paths[index]
        label = self.labels[img_path]

        # `convert` returns a new, fully loaded image, so the file handle can be
        # released deterministically as soon as the `with` block exits.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        if self.transform is not None:
            # Convert the image to numpy array
            img_np = np.array(img)

            img = self.transform(img_np)

        return img, label


In [17]:
# Build a single dataset in `ssl` mode that spans the test, train and val
# folders under `data_path`; `session='test'` is what gets handed to the
# transform factory inside the class.
dataset = MotorBikeDataset(
    data_mode='ssl',
    session='test',
    config_path='src/mortobike_project/config',
    folder_paths=[os.path.join(data_path, split) for split in ('test', 'train', 'val')],
)

# Smoke check: pull one (image, label) pair from deep inside the dataset.
dataset[20000]

(tensor([[[ 0.0741,  0.0569,  0.0227,  ...,  0.3309,  0.3138,  0.3309],
          [ 0.2111,  0.1939,  0.1597,  ...,  0.8276,  0.8618,  0.8789],
          [ 0.3481,  0.3309,  0.3138,  ...,  0.9988,  1.0331,  1.0502],
          ...,
          [-0.2513, -0.1657, -0.0629,  ..., -0.1999, -0.0116,  0.2967],
          [-0.3712, -0.2856, -0.1486,  ..., -0.1486,  0.0569,  0.3138],
          [-0.3883, -0.2856, -0.1486,  ..., -0.0972,  0.1083,  0.3138]],
 
         [[ 0.2052,  0.1877,  0.1527,  ...,  0.4328,  0.4153,  0.4328],
          [ 0.3452,  0.3277,  0.2927,  ...,  0.9405,  0.9580,  0.9930],
          [ 0.4853,  0.4678,  0.4503,  ...,  1.0980,  1.1331,  1.1506],
          ...,
          [-0.0749,  0.0126,  0.1352,  ..., -0.9153, -0.7052, -0.3725],
          [-0.1975, -0.1099,  0.0301,  ..., -0.8978, -0.6702, -0.3901],
          [-0.2150, -0.1099,  0.0301,  ..., -0.8627, -0.6352, -0.4076]],
 
         [[ 0.4265,  0.4091,  0.3742,  ...,  0.4265,  0.4091,  0.4265],
          [ 0.5659,  0.5485,

In [21]:
# A sample of y_hat to understand the function y_hat.argmax(dim=1)
import torch

# Ground-truth class index for each of the 3 samples. It must have exactly one
# entry per row of y_hat, otherwise the element-wise comparison in the next
# cell fails with a size mismatch.
y = torch.tensor([2, 0, 1])

# One row of class scores per sample, one column per class.
y_hat = torch.tensor([
    [0.1, 0.2, 0.9],  # sample 0: highest score is in column 2
    [0.8, 0.1, 0.1],  # sample 1: highest score is in column 0
    [0.3, 0.6, 0.1],  # sample 2: highest score is in column 1
])

# Per column: the row index of the max value -> tensor([1, 2, 0]).
# Not displayed, because only the last expression of a cell is echoed.
y_hat.argmax(dim=0)
# Per row (sample): the column index of the max value -> tensor([2, 0, 1]).
y_hat.argmax(dim=1)

tensor([2, 0, 1])

In [26]:
# Accuracy: fraction of samples whose highest-scoring class index equals the label.
y_hat.argmax(dim=1).eq(y).float().mean().item()

1.0