# データセットのバランス

いずれかのデータのほうが多い場合は、学習が崩れてしまうので、ある程度に各種のカテゴリーから同じデータ量があると良い

In [1]:
import random

# データセットを読み込み、余計なデータを削除
def cleanup_dataset():

    import csv
    
    raw_dataset = []

    csv_file = open("data/password_strength_raw.csv", newline='', encoding="utf-8_sig")
    reader = csv.reader(csv_file)

    for line in reader:
         
        strength = line[0]
        password = line[1]
        
        if not strength.isnumeric():
            continue

        if not isinstance(password, str):
            password = str(password)
        
        if password == "":
            continue
            
        raw_dataset.append((password, int(strength)))

    return raw_dataset



In [2]:
# データを読み込む
cleanup = cleanup_dataset()

In [3]:
print(cleanup[0:50])

[('123456', 0), ('12345', 0), ('123456789', 0), ('password', 0), ('iloveyou', 0), ('princess', 0), ('1234567', 0), ('rockyou', 1), ('12345678', 0), ('abc123', 0), ('nicole', 0), ('daniel', 0), ('babygirl', 0), ('monkey', 0), ('lovely', 0), ('jessica', 0), ('654321', 0), ('michael', 0), ('ashley', 0), ('qwerty', 0), ('111111', 0), ('iloveu', 0), ('000000', 0), ('michelle', 0), ('tigger', 0), ('sunshine', 0), ('chocolate', 1), ('password1', 0), ('soccer', 0), ('anthony', 0), ('friends', 0), ('butterfly', 0), ('purple', 0), ('angel', 0), ('jordan', 0), ('liverpool', 0), ('justin', 0), ('loveme', 0), ('fuckyou', 0), ('123123', 0), ('football', 0), ('secret', 0), ('andrea', 0), ('carlos', 0), ('jennifer', 0), ('joshua', 0), ('bubbles', 0), ('1234567890', 0), ('superman', 0), ('hannah', 0)]


In [4]:
# バランスを取る
def balance_dataset(raw_dataset):
    
    # 0, 1, 2, 3, 4の各種のパスワードを分ける
    sub_dataset = [[], [], [], [], []]
    for pwd in raw_dataset:
        sub_dataset[pwd[1]].append(pwd)
        
    # 最も少ないグループをはどれ？
    max_count = len(sub_dataset[0])
    max_count = min(max_count, len(sub_dataset[1]))
    max_count = min(max_count, len(sub_dataset[2]))
    max_count = min(max_count, len(sub_dataset[3]))
    max_count = min(max_count, len(sub_dataset[4]))
    
    # 少し余裕(25%)を作る
    max_count *= 1.25
    
    # バランス前の状態
    print(len(sub_dataset[0]), len(sub_dataset[1]), len(sub_dataset[2]), len(sub_dataset[3]), len(sub_dataset[4]))
    
    # オーバーしているグループからデータを消す
    for i in range(5):
        
        data_len = len(sub_dataset[i])
        while (data_len > max_count):

            # ランダムの範囲を選択（最大200個)
            min_idx = random.randrange(data_len)
            max_idx = min_idx + random.randrange(1, 200)
            max_idx = min(max_idx, data_len)
            
            # 削除
            del sub_dataset[i][min_idx:max_idx]
            
            # グループの長さを求める
            data_len = len(sub_dataset[i])
            print("\r" + str(data_len) + "            ", end="")

        print("")
        
    # バランスあとの状態
    print(len(sub_dataset[0]), len(sub_dataset[1]), len(sub_dataset[2]), len(sub_dataset[3]), len(sub_dataset[4]))
    
    # 全のグループを合成
    merged =  sub_dataset[0]
    merged += sub_dataset[1]
    merged += sub_dataset[2]
    merged += sub_dataset[3]
    merged += sub_dataset[4]
    
    # シャッフル
    random.shuffle(merged)
    return merged

In [5]:
balance = balance_dataset(cleanup)

29903 5116664 5214085 3026092 957646

37336              
37280              
37241              
37320             
29903 37336 37280 37241 37320


In [6]:
print(balance[0:50])

[('nahpets', 0), ('sasa2727', 1), ('lilboost', 2), ('4408tiiyt', 3), ('schuessler', 3), ('1033704045', 3), ('smurfy23', 2), ('milagrsy gernys', 4), ('7h0mp50n', 0), ('wutsuphomeboy', 3), ('TERESA', 0), ('hoddboty1', 3), ('patty', 0), ('lucky038', 1), ('Tummykiwi151', 4), ('pipsonly1', 3), ('akamaru16', 3), ('45192480093', 3), ('smurf26', 2), ('25761838', 2), ('4snickers', 1), ('informacion(29', 4), ('ams', 0), ('2890427226702', 4), ('vivalavida72', 4), ('84671395', 2), ('141079101104', 4), ('930221065280*', 4), ('jhahjhab', 2), ('chaminda', 2), ('Cilit123', 2), ('la', 0), ('espin0za', 0), ('?*?+&^&', 2), ('dixiedixie', 0), ('vtiscool1', 3), ('iraronnel', 3), ('hugolasso1982', 4), ('0802405954', 3), ('agudelorosero924', 4), ('a2353885', 2), ('5288610', 2), ('42dodge', 1), ('STUBBY', 1), ('slubber1', 2), ('santo5', 0), ('CALLME0103', 3), ('rockyou109', 2), ('085223420132', 3), ('brad784ever388', 4)]


In [7]:
# データを保存
def save_dataset(dataset):

    import csv
    
    raw_dataset = []

    csv_file = open("data/password_strength.csv", 'w', newline='', encoding="utf-8_sig")
    writer = csv.writer(csv_file)

    for item in dataset:   
        writer.writerow(item)
        
save_dataset(balance)