In [18]:
import json
from torch.utils.data import Dataset, DataLoader

# Load CSI_data.json into `data`.
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default (JSON interchange text is UTF-8).
with open('CSI_data.json', encoding='utf-8') as f:
    data = json.load(f)

# Define the custom Dataset class


class CSIDataset(Dataset):
    """Dataset over the file roots that pass a caller-supplied predicate.

    Args:
        data: Mapping of class name -> iterable of file-root strings.
        filter_func: Callable ``(root_parts, file_root)`` whose truthy result
            keeps the entry. ``root_parts`` is whatever ``split_root`` returns
            for that file root.

    The filtering is done eagerly in ``__init__``; indexing afterwards only
    reads the precomputed ``filtered_data`` list.
    """

    def __init__(self, data, filter_func):
        self.data = data
        self.filter_func = filter_func
        # Run the predicate over every entry once, up front.
        self.filtered_data = self.filter_data()

    def filter_data(self):
        """Return the list of file roots accepted by ``self.filter_func``.

        Prints one summary line per class, and one line for every entry whose
        predicate call raised (that entry is then dropped).
        """
        kept = []
        for class_name, roots in self.data.items():
            # Report the class and how many entries it holds.
            print(f"Class: {class_name}, Files: {len(roots)} entries")
            for root in roots:
                # Parsing happens outside the try block, so an exception
                # raised by split_root propagates instead of being logged.
                parsed = split_root(root)
                try:
                    # The predicate receives the parsed parts, not class_name.
                    accepted = bool(self.filter_func(parsed, root))
                    if accepted:
                        kept.append(root)
                except Exception as e:
                    print(f"Error filtering {root}: {e}")
        return kept

    def __len__(self):
        """Number of entries that passed the filter."""
        return len(self.filtered_data)

    def __getitem__(self, idx):
        """Return the file-root string stored at position ``idx``."""
        return self.filtered_data[idx]

# Adjusted filter functions

# Helper function to split file root


def split_root(file_root):
    """Split a '/'-separated file root into its named components.

    Args:
        file_root: Path-like string expected to contain exactly six
            '/'-separated segments.

    Returns:
        dict with keys CLASS_NAME, GENDER_COUNT, POSITION, TIME and FILE_NAME,
        taken from segments 0, 2, 3, 4 and 5. Segment 1 is not returned.

    Raises:
        ValueError: if the string does not split into exactly six segments.
    """
    segments = file_root.split('/')
    if len(segments) != 6:
        raise ValueError(f"File root is not correctly formatted: {file_root}")

    # The second segment is deliberately discarded.
    class_name, _, gender_count, position, time_value, file_name = segments
    return {
        "CLASS_NAME": class_name,
        "GENDER_COUNT": gender_count,
        "POSITION": position,
        "TIME": time_value,
        "FILE_NAME": file_name,
    }

In [19]:
# Create a DataLoader for filtered data
def get_dataloader(data, filter_func):
    """Wrap the entries of ``data`` accepted by ``filter_func`` in a DataLoader.

    Args:
        data: Mapping handed straight to ``CSIDataset``.
        filter_func: Predicate handed straight to ``CSIDataset``.

    Returns:
        A ``DataLoader`` over the filtered dataset that yields one entry per
        batch, in dataset order (no shuffling).
    """
    return DataLoader(
        CSIDataset(data, filter_func),
        batch_size=1,
        shuffle=False,
    )

# Function to save filtered data to JSON file


def save_filtered_data(dataloader, filename):
    """Write every entry yielded by ``dataloader`` to a JSON file.

    Args:
        dataloader: Iterable yielding either a single string or a batch
            (any iterable of entries, e.g. the list a DataLoader produces
            for string samples).
        filename: Path of the JSON file to create or overwrite.

    The file holds one object mapping a running index (as a string, starting
    at "0") to each entry, in iteration order.
    """
    # Flatten batches so that every entry is kept. Taking only element 0 of a
    # batch would silently drop the rest whenever a batch holds more than one
    # entry, and would raise on an empty batch.
    entries = []
    for batch in dataloader:
        if isinstance(batch, str):
            # A bare string is one entry; do not iterate its characters.
            entries.append(batch)
        else:
            entries.extend(batch)

    results_dict = {str(index): entry for index, entry in enumerate(entries)}

    # Write the dictionary to a JSON file with an explicit encoding.
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results_dict, f, indent=4)

In [20]:
# Requirement 1: CLASS_NAME contains "Env3"
def filter_req_1(root_parts, file_root):
    """Requirement 1 predicate: keep entries whose CLASS_NAME contains "Env3".

    Args:
        root_parts: Mapping with at least a "CLASS_NAME" entry.
        file_root: Unused; accepted so all predicates share one signature.

    Returns:
        True when the substring "Env3" occurs in CLASS_NAME.
    """
    class_name = root_parts["CLASS_NAME"]
    return "Env3" in class_name


# Requirement 1: build the loader with filter_req_1, then write the matching
# file roots to this requirement's JSON output file.
dataloader_1 = get_dataloader(data, filter_req_1)
save_filtered_data(dataloader_1, 'A1_313831013_游明睿_1.json')

Class: train, Files: 895130 entries
Class: val, Files: 8433 entries
Class: test, Files: 9621 entries


In [21]:
# Requirement 2: "THE_GENDER_AND_COUNT" contains 2 females with no limit on the number of males
def filter_req_2(root_parts, file_root):
    """Requirement 2 predicate: keep entries with exactly two 'F' characters.

    Args:
        root_parts: Mapping with at least a "GENDER_COUNT" entry.
        file_root: Unused; accepted so all predicates share one signature.

    Returns:
        True when GENDER_COUNT contains the capital letter 'F' exactly twice.
        Males are not examined.
    """
    # Presumably each female participant contributes one capital 'F' to the
    # field - verify against the actual folder naming.
    required_females = 2
    return root_parts["GENDER_COUNT"].count('F') == required_females


# Requirement 2: build the loader with filter_req_2, then write the matching
# file roots to this requirement's JSON output file.
dataloader_2 = get_dataloader(data, filter_req_2)
save_filtered_data(dataloader_2, 'A1_313831013_游明睿_2.json')

Class: train, Files: 895130 entries
Class: val, Files: 8433 entries
Class: test, Files: 9621 entries


In [22]:
def filter_req_3(root_parts, file_root):
    """Requirement 3 predicate: keep entries whose GENDER_COUNT has "Female".

    Args:
        root_parts: Mapping with at least a "GENDER_COUNT" entry.
        file_root: Unused; accepted so all predicates share one signature.

    Returns:
        True when the substring "Female" occurs in GENDER_COUNT.
    """
    gender_field = root_parts["GENDER_COUNT"]
    return "Female" in gender_field


# Requirement 3: build the loader with filter_req_3, then write the matching
# file roots to this requirement's JSON output file.
dataloader_3 = get_dataloader(data, filter_req_3)
save_filtered_data(dataloader_3, 'A1_313831013_游明睿_3.json')

Class: train, Files: 895130 entries


Class: val, Files: 8433 entries
Class: test, Files: 9621 entries


In [23]:
# Requirement 4: "TIME" contains from 240506_181307 to 240507_232434
def filter_req_4(root_parts, file_root):
    """Requirement 4 predicate: keep entries whose TIME is inside the window.

    Args:
        root_parts: Mapping with at least a "TIME" entry.
        file_root: Unused; accepted so all predicates share one signature.

    Returns:
        True when TIME lies between "240506_181307" and "240507_232434",
        both bounds included.
    """
    # The bounds are compared as plain strings. That matches chronological
    # order only if every TIME value uses the same fixed-width layout as the
    # bounds - TODO confirm against the data.
    window = ("240506_181307", "240507_232434")
    stamp = root_parts["TIME"]
    return window[0] <= stamp <= window[1]


# Requirement 4: build the loader with filter_req_4, then write the matching
# file roots to this requirement's JSON output file.
dataloader_4 = get_dataloader(data, filter_req_4)
save_filtered_data(dataloader_4, 'A1_313831013_游明睿_4.json')

Class: train, Files: 895130 entries
Class: val, Files: 8433 entries
Class: test, Files: 9621 entries


In [24]:
# Requirement 5: "CLASS_NAME" contains "Env3", "GENDER_COUNT" contains exactly 1 male, "POSITION" is "5_posi", and "TIME" is from 240508_090000 to 240508_110000
def filter_req_5(root_parts, file_root):
    """Requirement 5 predicate: Env3, exactly one male, position 5, time window.

    Args:
        root_parts: Mapping with "CLASS_NAME", "GENDER_COUNT", "POSITION" and
            "TIME" entries.
        file_root: Unused; accepted so all predicates share one signature.

    Returns:
        True only when all of the following hold:
        * CLASS_NAME contains "Env3";
        * GENDER_COUNT contains the token "Male" exactly once;
        * POSITION equals "5_posi";
        * TIME lies between "240508_090000" and "240508_110000" inclusive.
    """
    # Extract the relevant parts
    class_name = root_parts["CLASS_NAME"]
    gender_count = root_parts["GENDER_COUNT"]
    position = root_parts["POSITION"]
    time_value = root_parts["TIME"]

    class_name_match = "Env3" in class_name

    # "Exactly 1 male": count occurrences rather than testing mere presence,
    # which would also accept entries with two or more males. The match is
    # case-sensitive, so the lowercase "male" inside "Female" is not counted.
    # NOTE(review): assumes each male participant contributes one "Male"
    # token to the field - confirm against the actual folder naming.
    gender_count_match = gender_count.count("Male") == 1

    position_match = position == "5_posi"

    # Plain string comparison; chronological only for fixed-width TIME values.
    start_time = "240508_090000"
    end_time = "240508_110000"
    time_match = start_time <= time_value <= end_time

    # Return True if all conditions match
    return class_name_match and gender_count_match and position_match and time_match


# Requirement 5: build the loader with filter_req_5, then write the matching
# file roots to this requirement's JSON output file.
dataloader_5 = get_dataloader(data, filter_req_5)
save_filtered_data(dataloader_5, 'A1_313831013_游明睿_5.json')

Class: train, Files: 895130 entries
Class: val, Files: 8433 entries
Class: test, Files: 9621 entries
