In [205]:
import torch
import json
from torch.utils.data import Dataset, DataLoader

In [206]:
# 想把這個做成通用Class, 所以多了幾個額外funcion
class CSI_Dataset(Dataset):
    # 這邊用num不用count當參數, 是怕和str.count()搞混, 所以使用num, gender則是用來指定男或女
    # 一樣利用time_list來取出特定時間, 主要是放開始和結束時間
    def __init__(self, split='train', mask_list=[], gender='F', num=-1, time_list=[]):
        self.split = split
        with open('./CSI_data.json', 'r') as file:
            self.json_data = json.load(file)

        self.data_split = self.json_data[split]
        
        if len(mask_list):
            # 確定mask list 正確
            print(f'have mask: {mask_list}')
            data_mask = self.data_split
            # 原先的版本如下, 會造成mask完後有機率有重複數據, 所以稍作修改
            ########################################################################################################################
            # data_mask = []
            # for mask in mask_list:
            #   data_mask.extend(self.data_split[index] for index in range(len(self.data_split)) if mask in self.data_split[index])
            ########################################################################################################################
            
            # 修改後的版本, 此版本有另一個缺點是會把所有mask list內的mask一起做
            # #所以只有完整符合所有mask list的數據會留下
            ########################################################################################################################
            for mask in mask_list:
                data_mask = [data for data in data_mask if mask in data]
            ########################################################################################################################
            self.data = data_mask
        else:
            print(f'No mask: {mask_list}')
            self.data = self.data_split
        
        if num > 0:
            self.__gendercount__(num,gender)
        
        if len(time_list):
            self.__time_compare__(time_list)
            
    def __len__(self):
        return len(self.data)
            
    def __getitem__(self, idx):
        return self.data[idx]
    
    def __gendercount__(self,num,gender):
        data_count = []
        # 這邊用count來計算gender數量, 可以很方便地計算出contain number.
        data_count.extend(name for name in self.data if name.count(gender) == num)
        self.data = data_count
                
    def __time_compare__(self, time_list):
        start_date = int(time_list[0].split('_')[0])
        start_time = int(time_list[0].split('_')[1])
        end_date = int(time_list[1].split('_')[0])
        end_time = int(time_list[1].split('_')[1])
        print(start_date, start_time, end_date, end_time)

        data_timestep = []
        # 比較日期是不是在範圍內
        for data in self.data:
            data_date = int(data.split('/')[4].split('_')[0])
            data_time = int(data.split('/')[4].split('_')[1])
            # 這樣寫是因為測資會有日期相同時間不同的情況, 所以先比較這個例外狀況
            if start_date == end_date:
                if data_date == start_date and data_time >= start_time and data_time <= end_time:
                    data_timestep.append(data)
            else:
                if data_date == start_date and data_time >= start_time or data_date == end_date and data_time <= end_time:
                    data_timestep.append(data)
                elif data_date > start_date and data_date < end_date:
                    data_timestep.append(data)
        self.data = data_timestep
        

In [207]:
def Write_Output(output,filename='hw1'):
    data_dict = {i: value for i, value in enumerate(output)}

    with open(f'./A1_313834004_周彥宏_{filename}.json', 'w') as json_file:
        json.dump(data_dict, json_file, indent=4) # indent為美化輸出, 很像pretty.print in json output

# Requirement 1

In [208]:
dataset = CSI_Dataset(split='train',mask_list=['Env3'])
dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
print(dataset.__len__())

have mask: ['Env3']
158110


In [209]:
output = next(iter(dataloader))
output.__len__()

158110

In [210]:
# write the output file
Write_Output(output=output, filename='1')

# Requirement 2
In val_set and test_set not contain over 2 female, only consider train set.

In [211]:
# First step to split dataset from train and mash all Female
dataset = CSI_Dataset(split='train',mask_list=['F'], gender='F', num=2)

dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
print(dataset.__len__())

have mask: ['F']
128820


In [212]:
output = next(iter(dataloader))
output[128500:128819]

['Env5/npy/F3M1F1/5_posi/240508_220633/1715177242836864593',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177242887880453',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177242937330041',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177242988022979',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243037473498',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243088516164',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243140173383',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243190397998',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243241374825',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243292252306',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243342614390',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243391591019',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243442398051',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243492942674',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243542950105',
 'Env5/npy/F3M1F1/5_posi/240508_220633/1715177243592651699',
 'Env5/npy/F3M1F1/5_posi

In [213]:
# write the output file
Write_Output(output=output, filename='2')

# Requirement 3
In this dataset, if only contain 1 female, the format be "Female" not "F", so we just need to mask "Female".<br>
BTW, in val_set and test_set only contain 1 Female or 1 Male, we need to combine val_set, train_set and test_set to record.

In [214]:
dataset_train = CSI_Dataset(split='train', mask_list=['Female'])
dataset_val = CSI_Dataset(split='val', mask_list=['Female'])
dataset_test = CSI_Dataset(split='test', mask_list=['Female'])

print(len(dataset_train), len(dataset_val), len(dataset_test))
dataLoader_train = DataLoader(dataset_train, batch_size=len(dataset_train), shuffle=False)
dataLoader_val = DataLoader(dataset_val, batch_size=len(dataset_val), shuffle=False)
dataLoader_test = DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=False)

have mask: ['Female']
have mask: ['Female']
have mask: ['Female']
229679 3622 4807


In [215]:
# combine all dataset to output, and check it contain all of "Female" 
output = []
output.extend(next(iter(dataLoader_train)))
output.extend(next(iter(dataLoader_val)))
output.extend(next(iter(dataLoader_test)))
print(len(output))
output[233001:233301]

238108


['val_set/npy/Female3/rand_posi/240509_111828/1715224764352084857',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764397858836',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764444951954',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764492160867',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764537815576',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764584097742',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764630081099',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764677552059',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764724263855',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764771677353',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764819196479',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764866966869',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764914710796',
 'val_set/npy/Female3/rand_posi/240509_111828/1715224764962805620',
 'val_set/npy/Female3/rand_posi/240509_111828/17

In [216]:
# write the output file
Write_Output(output=output, filename='3')

# Requriement 4
In this case, I want to use mask for y/m/d time, and to compare the h/m/s time.<br>
In val_set and test_set, time start at 240509, so i ignore val_set and test_set.<br>
Hit: 240509 isn't only in time format, also in filename. So I want to use split '/' to get the timestep.

In [217]:
# 這邊實驗一下 要怎麼用這個format 寫出time compare
s = 'Env5/npy/F2M3F3/5_posi/240508_231251/1715181210784950271'
# get timestep
print(s.split('/')[4])
# get h/m/s
print(s.split('/')[4].split('_')[1])


240508_231251
231251


In [218]:
dataset = CSI_Dataset(split='train',time_list=['240506_181307', '240507_232434'])
print(dataset.__len__())

dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
output = next(iter(dataloader))
# 和原始data比較, 看是不是正確
print(output[0],output[391066])


No mask: []
240506 181307 240507 232434


391067
Env0/npy/F1M1F2/5_posi/240506_223137/1715005908730193543 Env2/npy/None/1_posi/240507_174339/1715075149570053260


In [219]:
# write the output file
Write_Output(output=output, filename='4')

# Requirement 5
In this case only need data in trian set, bcz the "Position":"5_posi" & "Class_name":"Env3" is only in train set.<br>
And we need to contain just one male means "Male(x)". So I decide to use mask for "Male". 

In [220]:
dataset = CSI_Dataset(split='train',mask_list=['Env3','5_posi','Male'],time_list=['240508_090000', '240508_110000'])
print(dataset.__len__())
print(dataset[0])
dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
output = next(iter(dataloader))

have mask: ['Env3', '5_posi', 'Male']
240508 90000 240508 110000
18481
Env3/npy/Male1/5_posi/240508_100726/1715134057824090344


In [221]:
# write the output file
Write_Output(output=output, filename='5')