In [1]:
import pandas as pd
from utils import util

In [23]:
# Default video pools per emotion category.
#! Which videos are usable depends on the window size used to build the label
#! file; override these through the get_detaset(...) keyword arguments when it changes.
# Full video1-25 split, kept for reference:
#   comfort:    video3, video9, video12, video13, video8, video15, video16, video17, video10, video14, video11
#   discomfort: video18, video19, video20, video21, video22, video23, video24, video25
#   others:     video1, video2, video4, video5, video6, video7
DEFAULT_COMFORT_VIDEOS = ()
DEFAULT_DISCOMFORT_VIDEOS = ('video18', 'video19', 'video20', 'video21', 'video22', 'video23', 'video24')
DEFAULT_OTHERS_VIDEOS = ('video1', 'video2', 'video4', 'video5', 'video6', 'video7', 'video3', 'video9',
                         'video12', 'video13', 'video8', 'video15', 'video16', 'video10', 'video11')

# `emotion` column values: (target_label, non_target_label) for each target_emo.
_EMOTION_LABELS = {
    'comfort': (1, 2),
    'discomfort': (2, 1),
}
# `emotion` value of frames used to top up the negatives.
_OTHERS_LABEL = 0


def get_detaset(label_df, target_emo, negative_threshold=0.5, adjust_samples=None,
                comfort_videos=None, discomfort_videos=None, others_videos=None,
                random_state=0):
    '''
    Build a dataset with (as far as the pools allow) equal positive and negative
    sample counts, mainly for cross-validation.

    For every target video (chosen by ``target_emo``):
      1. positives: frames of the target video whose ``emotion`` is the target label;
      2. negatives from the next unused non-target video: its frames with the
         opposite emotion label, capped at ``int(pos_num * negative_threshold)`` rows;
      3. the remaining negatives are filled with ``emotion == 0`` frames from unused
         "others" videos until the negative count reaches the positive count, or
         the others pool runs out.
    Non-target and others videos are taken from the end of their pools. Each one
    is assigned to at most one target video.

    Args
    ----------
    label_df : DataFrame
        Label data with ``img_path`` and ``emotion`` columns. It is not modified.
    target_emo : str
        Emotion used for the positive samples: 'comfort' or 'discomfort'.
    negative_threshold : float
        At most ``int(pos_num * negative_threshold)`` non-target-label rows are
        taken per target video. When the others top-up succeeds, their share of
        the negatives therefore does not exceed ``negative_threshold``.
    adjust_samples : tuple(str, int), optional
        ``(video_name, n)``. When ``video_name`` is the non-target video being
        sampled and its pool exceeds the cap above, take ``n`` rows instead of
        the cap. ``n`` must not exceed that pool (``DataFrame.sample`` raises otherwise).
    comfort_videos, discomfort_videos, others_videos : sequence of str, optional
        Video pools. They default to ``DEFAULT_*_VIDEOS``. The caller's sequences
        are copied and never modified.
    random_state : int
        Seed passed to ``DataFrame.sample``.

    Returns
    -------
    df : DataFrame
        Selected rows plus a ``video_name`` column, sorted by ``img_path``,
        with a fresh 0..n-1 index.
    video_name_list : DataFrame
        One row per target video (in iteration order) and one column per assigned
        video (sorted by name). A cell is 1 if that video was assigned to the
        row's group, else 0.

    Raises
    ------
    ValueError
        If ``target_emo`` is neither 'comfort' nor 'discomfort'.
    '''
    if target_emo not in _EMOTION_LABELS:
        raise ValueError(f"target_emo must be 'comfort' or 'discomfort', got {target_emo!r}")
    target_label, not_target_label = _EMOTION_LABELS[target_emo]

    # Work on a copy so the caller's frame does not gain a `video_name` column.
    label_df = label_df.copy()
    # `.map(...)[0]` instead of `zip(*...)` so an empty label_df does not fail to unpack.
    label_df['video_name'] = label_df['img_path'].map(
        lambda path: util.get_video_name_and_frame_num(path)[0])

    # Copies: videos are consumed with pop() below.
    comfort = list(DEFAULT_COMFORT_VIDEOS if comfort_videos is None else comfort_videos)
    discomfort = list(DEFAULT_DISCOMFORT_VIDEOS if discomfort_videos is None else discomfort_videos)
    others = list(DEFAULT_OTHERS_VIDEOS if others_videos is None else others_videos)
    if target_emo == 'comfort':
        target_videos, non_target_videos = comfort, discomfort
    else:
        target_videos, non_target_videos = discomfort, comfort

    def select(video_name, emotion):
        # Rows of one video carrying the given emotion label, in file order.
        mask = (label_df['video_name'] == video_name) & (label_df['emotion'] == emotion)
        return label_df[mask]

    frames = []          # collected pieces, concatenated once at the end
    use_video_list = []  # videos assigned to each target video's group

    for video_name in target_videos:
        use_videos = [video_name]

        # extract positive samples
        target_df = select(video_name, target_label)
        pos_num = len(target_df)
        neg_num = 0
        print(f'target: {video_name}, len: {len(target_df)}')
        frames.append(target_df)

        # extract negative samples from one non-target video, capped by negative_threshold
        if non_target_videos:
            non_target_video_name = non_target_videos.pop()
            use_videos.append(non_target_video_name)
            non_target_df = select(non_target_video_name, not_target_label)
            cap = int(pos_num * negative_threshold)
            if len(non_target_df) > cap:
                if adjust_samples and adjust_samples[0] == non_target_video_name:
                    cap = adjust_samples[1]
                non_target_df = non_target_df.sample(cap, random_state=random_state)
            neg_num += len(non_target_df)
            print(f'non_target: {non_target_video_name}, len: {len(non_target_df)}')
            frames.append(non_target_df)

        # top up negatives with "others" frames until balanced; a video that
        # contributes no rows is still recorded as assigned
        while neg_num < pos_num:
            if not others:
                print('others_videos is empty')
                break
            add_others_video_name = others.pop()
            use_videos.append(add_others_video_name)
            add_others_df = select(add_others_video_name, _OTHERS_LABEL)
            if len(add_others_df) > pos_num - neg_num:
                add_others_df = add_others_df.sample(pos_num - neg_num, random_state=random_state)
            neg_num += len(add_others_df)
            print(f'others: {add_others_video_name}, len: {len(add_others_df)}')
            frames.append(add_others_df)

        print(f'pos_num: {pos_num}, neg_num: {neg_num}')
        print()
        use_video_list.append(use_videos)

    # With no target videos, return an empty frame that still has every column
    # (the original crashed here with KeyError: 'img_path').
    df = pd.concat(frames) if frames else label_df.iloc[0:0]
    df = df.sort_values('img_path').reset_index(drop=True)

    all_videos = sorted({name for videos in use_video_list for name in videos})
    video_name_list = pd.DataFrame(
        [[int(name in videos) for name in all_videos] for videos in use_video_list],
        columns=all_videos,
    )
    return df, video_name_list


# Correctly spelled alias; the original name is kept for existing callers.
get_dataset = get_detaset
        

In [27]:
# Source label CSV (absolute path; adjust for your environment).
label_csv_path = '/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/labels/PIMD_A/drop_discomfort_hardnegative_seq_labels(video1-25)_ver2-gazesign_wsize90-ssize3-th5e-01.csv'
label_df = pd.read_csv(label_csv_path)

In [28]:
# Build the balanced 'discomfort' dataset; per-video sample counts are printed below.
df, video_name_list = get_detaset(
    label_df,
    target_emo='discomfort',
    negative_threshold=0.5,
)

target: video18, len: 7
others: video11, len: 7
pos_num: 7, neg_num: 7

target: video19, len: 451
others: video10, len: 309
others: video16, len: 142
pos_num: 451, neg_num: 451

target: video20, len: 152
others: video15, len: 152
pos_num: 152, neg_num: 152

target: video21, len: 384
others: video8, len: 286
others: video13, len: 98
pos_num: 384, neg_num: 384

target: video22, len: 68
others: video12, len: 33
others: video9, len: 35
pos_num: 68, neg_num: 68

target: video23, len: 4
others: video3, len: 4
pos_num: 4, neg_num: 4

target: video24, len: 7
others: video7, len: 7
pos_num: 7, neg_num: 7



In [29]:
#! Don't forget to change these output file names to match the settings used above.
labels_out_path = '/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/labels/PIMD_A/discomfort-gazesign_labels_wsize90-ssize3.csv'
video_list_out_path = '/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/labels/PIMD_A/discomfort-gazesign_video_name_list_wsize90-ssize3.csv'

df.to_csv(labels_out_path, index=False)
video_name_list.to_csv(video_list_out_path, index=False)