In [1]:
import itertools

import numpy as np

In [2]:
def load_csv(filename):
    data = []
    with open(filename, "r") as f:
        for line in f:
            data.append(tuple(line.strip().split(",")))
    return data

splits = ["train", "test", "query"]
data = {
    s: load_csv(f"../data/market1501_{s}.csv")
    for s in splits
}

In [3]:
len(data["train"])

12936

In [4]:
raw_data = list(itertools.chain.from_iterable(v for v in data.values()))
len(raw_data)

36036

In [5]:
# Remove duplicates
raw_data = list(set(raw_data))
len(raw_data)

36036

In [6]:
# Remove unlabeled data
raw_data = [d for d in raw_data if d[0] not in ["-1", "0000"]]
len(raw_data)

29419

In [7]:
ids = [d[0] for d in raw_data]
unique_ids = list(set(ids))
len(ids), len(unique_ids)

(29419, 1501)

In [8]:
np.random.seed(1123)
np.random.shuffle(unique_ids)

In [9]:
num_gt_m = 375
num_gt_nm = 376
num_sh_m = 375
# num_sh_nm = 375

gt_m_ids = unique_ids[:num_gt_m]
gt_nm_ids = unique_ids[num_gt_m:num_gt_m+num_gt_nm]
sh_m_ids = unique_ids[num_gt_m+num_gt_nm:num_gt_m+num_gt_nm+num_sh_m]
sh_nm_ids = unique_ids[num_gt_m+num_gt_nm+num_sh_m:]

len(gt_m_ids), len(gt_nm_ids), len(sh_m_ids), len(sh_nm_ids)

(375, 376, 375, 375)

In [10]:
np.random.seed(1123)

gt_m = [d for d in raw_data if d[0] in gt_m_ids]
num_gt_tm = len(gt_m) // 2
np.random.shuffle(gt_m)
gt_tm = gt_m[:num_gt_tm]
gt_ntm = gt_m[num_gt_tm:]

gt_nm = [d for d in raw_data if d[0] in gt_nm_ids]

sh_m = [d for d in raw_data if d[0] in sh_m_ids]
num_sh_tm = len(sh_m) // 2
np.random.shuffle(sh_m)
sh_tm = sh_m[:num_sh_tm]
sh_ntm = sh_m[num_sh_tm:]

sh_nm = [d for d in raw_data if d[0] in sh_nm_ids]

len(gt_tm), len(gt_ntm), len(gt_nm), len(sh_tm), len(sh_ntm), len(sh_nm)

(3568, 3569, 7737, 3762, 3762, 7021)

In [11]:
def save_data(filename, data):
    with open(filename, "w") as f:
        f.write('\n'.join([','.join(d) for d in data]))

splitted_data = {
    "ground_truth_training_member": gt_tm,
    "ground_truth_non_training_member": gt_ntm,
    "ground_truth_non_member": gt_nm,
    "shadow_training_member": sh_tm,
    "shadow_non_training_member": sh_ntm,
    "shadow_non_member": sh_nm,
}
for s, d in splitted_data.items():
    save_data(f"../splitted_data/{s}.csv", d)