In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset

In [2]:
class WiFiDataset(Dataset):
    """Dataset of WiFi spectrum ``.npy`` files discovered under ``root_dir``.

    Spectrum files are found recursively. They can be selected by fields
    encoded in their underscore-separated file names; the field positions
    are listed in ``filename_part_mapping``.

    Args:
        root_dir: Directory searched recursively for ``.npy`` spectrum files.
        task: Either ``'classify'`` or ``'pose'``.
        transform: Optional callable applied to each loaded spectrum array.
        config: Optional selection dict:
            - ``None``: select every spectrum.
            - contains ``'dataset_index_set'``: select by dataset index
              (a ``set`` of ints drawn from ``{1, 2, 3, 4}``).
            - contains any of ``'body_position'``, ``'body_orientation'``,
              ``'person'``: filter on those file-name fields; missing keys
              (or ``None`` values) mean "no filter" for that field.
            - anything else: empty dataset.
    """

    # Config keys that are forwarded to get_specific_data.
    _FIELD_FILTER_KEYS = ('body_position', 'body_orientation', 'person')

    def __init__(self, root_dir, task='classify', transform=None, config=None):
        assert task in {'classify', 'pose'}, f'invalid task: {task}'

        self.root_dir = root_dir
        self.task = task
        self.transform = transform

        # Valid values for the 'dataset' file-name field.
        self.dataset_indexs = {1, 2, 3, 4}
        # Position of each field in the '_'-separated file-name stem.
        self.filename_part_mapping = {
            "dataset": 0,
            "scene": 1,
            "receiver_position": 2,
            "body_position": 3,
            "body_orientation": 4,
            "action": 5,
            "person": 6,
            "action_repetition": 7,
            "spectrum": 8
        }

        if config is None:
            self.spectrums = self.get_all_spectrums()
            self.ground_truths = self.get_corresponding_ground_truths(self.spectrums)
        elif 'dataset_index_set' in config:
            self.spectrums, self.ground_truths = self.get_specific_datasets(config['dataset_index_set'])
        elif any(key in config for key in self._FIELD_FILTER_KEYS):
            # Any subset of the filter keys is accepted; absent keys mean "don't filter".
            self.spectrums, self.ground_truths = self.get_specific_data(
                config.get('body_position'),
                config.get('body_orientation'),
                config.get('person'),
            )
        else:
            self.spectrums, self.ground_truths = [], []

    def __len__(self):
        """Return the number of selected spectrum files."""
        return len(self.spectrums)

    def __getitem__(self, idx):
        """Load the spectrum and ground truth at ``idx``.

        Returns:
            ``(spectrum, ground_truth)``. ``spectrum`` is the loaded array,
            passed through ``transform`` when one is set. ``ground_truth`` is
            a ``torch.Tensor`` built from the loaded array.
        """
        spectrum_path, ground_truth_path = self.spectrums[idx], self.ground_truths[idx]
        # NOTE(review): ground_truths currently holds placeholder strings
        # ('classify'/'pose'), which np.load cannot open. This path needs a
        # real ground-truth file mapping to work.
        spectrum, ground_truth = np.load(spectrum_path), np.load(ground_truth_path)
        if self.transform:
            spectrum = self.transform(spectrum)
        ground_truth = torch.from_numpy(ground_truth)
        return spectrum, ground_truth

    def get_all_spectrums(self):
        """Return every ``.npy`` path under ``root_dir``, sorted.

        Sorting makes sample indices independent of the filesystem's
        directory-listing order, so runs are reproducible.
        """
        spectrums = [
            os.path.join(root, file)
            for root, _, files in os.walk(self.root_dir)
            for file in files
            if file.endswith('.npy')
        ]
        return sorted(spectrums)

    def get_corresponding_ground_truths(self, spectrums):
        """Return one ground-truth entry per spectrum path.

        NOTE(review): this is a placeholder. Each entry is just the task
        name, not a path to a ground-truth file. Replace it with the real
        spectrum -> ground-truth mapping.
        """
        return [self.task for _ in spectrums]

    def get_filename_parts(self, filepath):
        """Split the file-name stem (no directory, no extension) on ``'_'``."""
        return os.path.splitext(os.path.basename(filepath))[0].split('_')

    def _filename_part(self, filepath, key):
        """Return the file-name field ``key``, or ``None`` if the name has too few fields."""
        parts = self.get_filename_parts(filepath)
        index = self.filename_part_mapping[key]
        return parts[index] if index < len(parts) else None

    def get_specific_datasets(self, dataset_index_set):
        """Select spectra whose 'dataset' field is one of ``dataset_index_set``.

        Args:
            dataset_index_set: ``set`` of ints, a subset of ``self.dataset_indexs``.

        Returns:
            ``(spectrums, ground_truths)`` lists.

        Raises:
            AssertionError: if the argument is not a set or contains unknown indices.
        """
        assert isinstance(dataset_index_set, set), 'the dataset_index_set is not a set'
        assert dataset_index_set.issubset(self.dataset_indexs), f'invalid dataset index set: {dataset_index_set}'

        def in_selected_datasets(spectrum):
            # File-name fields are strings while the indices are ints, so parse
            # before comparing. Missing or non-numeric fields never match.
            part = self._filename_part(spectrum, 'dataset')
            try:
                return int(part) in dataset_index_set
            except (TypeError, ValueError):
                return False

        spectrums = [spectrum for spectrum in self.get_all_spectrums() if in_selected_datasets(spectrum)]
        ground_truths = self.get_corresponding_ground_truths(spectrums)
        return spectrums, ground_truths

    def get_specific_data(self, body_position=None, body_orientation=None, person=None):
        """Select spectra whose file-name fields match every non-``None`` criterion.

        Criteria are compared as strings, so ``person=2`` and ``person='2'``
        both match a ``'2'`` field. Files whose names have too few fields
        never match an active criterion.

        Returns:
            ``(spectrums, ground_truths)`` lists.
        """
        filters = {
            'body_position': body_position,
            'body_orientation': body_orientation,
            'person': person,
        }
        active = {key: str(value) for key, value in filters.items() if value is not None}

        spectrums = [
            spectrum for spectrum in self.get_all_spectrums()
            if all(self._filename_part(spectrum, key) == wanted for key, wanted in active.items())
        ]
        ground_truths = self.get_corresponding_ground_truths(spectrums)
        return spectrums, ground_truths