In [1]:
import json

In [16]:
PITCH_DICT = {
    0: "HE",
    1: "SI",
    2: "YI",
    3: "SHANG",
    4: "GOU",
    5: "CHE",
    6: "GONG",
    7: "FAN",
    8: "LIU",
    9: "WU",
    10: "GAO_WU",
}
SECONDARY_DICT = {0: None, 1: "DA_DUN", 2: "XIAO_ZHU", 3: "DING_ZHU", 4: "DA_ZHU", 5: "ZHE", 6: "YE"}


class_to_annotation_dict = {}
idx = 0
for secondary in SECONDARY_DICT.keys():
    for pitch in PITCH_DICT.keys():
        class_to_annotation_dict[idx] = f"{PITCH_DICT[pitch]}, {SECONDARY_DICT[secondary]}"
        idx += 1

class_to_subclasses_dict = {}
idx = 0
for secondary in SECONDARY_DICT.keys():
    for pitch in PITCH_DICT.keys():
        class_to_subclasses_dict[idx] = (pitch, secondary)
        idx += 1


def _class_to_annotation(class_idx):
    return class_to_annotation_dict[class_idx]


def _annotation_to_class(annotation):
    for key in class_to_annotation_dict.keys():
        if class_to_annotation_dict[key] == annotation:
            return key
    raise Exception("Invalid annotation", annotation)


def properties_to_class(pitch, secondary):
    return _annotation_to_class(f"{pitch}, {secondary}")


def class_to_subclasses(class_idx):
    return class_to_subclasses_dict[class_idx]


def individual_labels_to_class(pitch, secondary):
    return properties_to_class(PITCH_DICT[int(pitch)], SECONDARY_DICT[int(secondary)])

66


In [49]:
def partition_by_annotation(l):
    per_annotation = {}
    for element_idx, element in enumerate(l):
        class_idx = properties_to_class(element["annotation"]["pitch"], element["annotation"]["secondary"])
        element["IDX"] = element_idx
        if class_idx not in per_annotation:
            per_annotation[class_idx] = []
        per_annotation[class_idx].append(element)
    return per_annotation


with open("dataset.json") as json_file:
    f = json.load(json_file)

per_annotation = partition_by_annotation(f)

samples = per_annotation.values()
samples = sorted(samples, key=lambda x: len(x), reverse=True)

samples_1 = []
samples_2 = []

for idx, s in enumerate(samples):
    half1 = s[: len(s) // 2]
    half2 = s[len(s) // 2 :]
    if len(half1) > 1 and idx % 2:
        half1, half2 = half2, half1
    else:
        if len([item for row in samples_1 for item in row]) < len([item for row in samples_2 for item in row]):
            half2, half1 = half1, half2

    samples_1.append(half1)
    samples_2.append(half2)

samples_1 = [item for row in samples_1 for item in row]
samples_2 = [item for row in samples_2 for item in row]

print(len(samples_1), len(samples_2))

for group, l in enumerate((samples_1, samples_2)):
    for idx in range(len(l)):
        l[idx]["group"] = group

all_samples = samples_1 + samples_2

all_samples = sorted(all_samples, key=lambda x: x["IDX"])

print([len(l) for l in partition_by_annotation(samples_1).values()])
print([len(l) for l in partition_by_annotation(samples_2).values()])

print(len(f), len(all_samples))

print(all_samples[0:10])

with open("dataset_grouped.json", "w") as json_file:
    json.dump(all_samples, json_file)

719 720
[84, 81, 60, 59, 57, 53, 39, 34, 21, 15, 14, 14, 11, 11, 10, 10, 9, 9, 8, 8, 8, 7, 7, 6, 5, 5, 4, 5, 4, 5, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[84, 81, 60, 59, 57, 52, 40, 33, 22, 15, 14, 14, 11, 10, 10, 10, 10, 9, 9, 8, 7, 7, 7, 5, 6, 5, 5, 4, 5, 4, 5, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1]
1439 1439
[{'image_path': 'images/shanghai_manuscript_035_030_geximeiling_18.png', 'type': 'Music', 'notation_type': 'Suzipu', 'annotation': {'pitch': 'SHANG', 'secondary': 'XIAO_ZHU'}, 'edition': 'shanghai', 'filename': '030_geximeiling.json', 'text_annotation': '好', 'IDX': 668, 'group': 1}, {'image_path': 'images/shanghai_manuscript_035_030_geximeiling_19.png', 'type': 'Music', 'notation_type': 'Suzipu', 'annotation': {'pitch': 'GONG', 'secondary': None}, 'edition': 'shanghai', 'filename': '030_geximeiling.json', 'text_annotation': '花', 'IDX': 84, 'group': 1}, {'image_path': 'images/shanghai_manuscript_035_030_geximei