In [None]:
from glob import glob
import os
from typing import List, Callable, Tuple

# glob() returns matches in arbitrary, filesystem-dependent order; sort them so
# the group/weight listings printed below come out in the same order on every run.
h5files = sorted(glob("/home_data/home/liyuyang/data2/datasets/CMRxRecon2025/maxsplitset/train/*.h5"))
print(f"Found {len(h5files)} h5 files")

def get_rule_counts(fnamelist: list, rule: Callable) -> Tuple[dict, dict]:
    """Group file names by the key that ``rule`` assigns to each one.

    Args:
        fnamelist: File paths to group.
        rule: Maps a file path to its group key.

    Returns:
        ``(groups, counts)``: ``groups`` maps each key to the list of paths in
        that group (in input order); ``counts`` maps each key to its group size.
    """
    buckets: dict = {}
    for path in fnamelist:
        key = rule(path)
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(path)
    sizes = {key: len(members) for key, members in buckets.items()}
    return buckets, sizes

def base_rule(fname: str) -> str:
    """Return the file's base name with its last two ``@``-separated fields dropped.

    E.g. ``/x/a@b@c@d.h5`` -> ``a@b``. A base name with two or fewer fields
    yields the empty string.
    """
    fields = fname.rsplit("/", 1)[-1].split("@")
    return "@".join(fields[:-2])


print("\nBase groups and counts:")
basegroups, basecounts = get_rule_counts(h5files, base_rule)
# Report how many files fall into each base group.
for groupname in basecounts:
    print(f"Global group: {groupname}, Count: {basecounts[groupname]}")

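# Balance weights based on rules
# 1. For each rule: if the group a sample belongs to is larger than the mean group size, lower its weight; if smaller, raise it.
# 2. Maximum weight 4, minimum weight 1 (the defaults of balance_weights).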

def momentum_func(ratio: float) -> float:
    """Damp a balancing ratio halfway towards 1.

    Computes ``1 + 0.5 * (ratio - 1)``: 2 becomes 1.5, 0.5 becomes 0.75,
    and 1 stays 1.
    """
    return 1 + 0.5 * (ratio - 1)

def balance_weights(weights: dict, rules: List[Tuple[Callable, dict]], max_weight: int = 4, min_weight: int = 1, momentum: Callable = None):
    """Rebalance per-file weights so files in under-represented groups count more.

    For every rule, a file's weight is multiplied by
    ``mean_group_size / size_of_its_group`` (optionally passed through
    ``momentum``), where ``mean_group_size = len(weights) / number_of_groups``.
    Files in groups larger than average are therefore down-weighted and files
    in smaller groups up-weighted. The accumulated product is clipped to
    ``[min_weight, max_weight]`` and rounded with Python's ``round`` (exact
    halves go to the even neighbour).

    Args:
        weights: Maps each file name to its starting weight. Updated in place.
        rules: ``(rule, counts)`` pairs: ``rule`` maps a file name to a group
            key and ``counts`` maps every such key to its group size (e.g. the
            counts produced by ``get_rule_counts``). Any iterable is accepted;
            it is consumed once.
        max_weight: Upper clip bound.
        min_weight: Lower clip bound.
        momentum: Optional function applied to each per-rule ratio, e.g. to
            damp it towards 1.

    Returns:
        The same ``weights`` dict, each value replaced by its balanced,
        rounded weight.

    Raises:
        KeyError: If a file's group key is missing from a rule's ``counts``.
        ZeroDivisionError: If a rule's ``counts`` is empty while ``weights``
            is not.
    """
    if not weights:
        # Nothing to balance; avoid evaluating rule means (0 / 0) needlessly.
        return weights

    nsample = len(weights)  # Global number of samples
    # The mean group size depends only on the rule, not on the file: compute it
    # once per rule. Materializing the rules here also means a one-shot
    # iterable (e.g. a generator) is applied to every file, not just the first.
    rule_means = [(rule, counts, nsample / len(counts)) for rule, counts in rules]

    for fname, weight in weights.items():
        for rule, counts, meanvalue in rule_means:
            # Number of samples in the group this file belongs to under this rule.
            samplevalue = counts[rule(fname)]
            ratio = meanvalue / samplevalue
            if momentum is not None:
                ratio = momentum(ratio)
            weight = weight * ratio

        # Clip to [min_weight, max_weight], then round to an integer weight.
        # Reassigning existing keys while iterating is safe (dict size is unchanged).
        weight = max(min_weight, min(max_weight, weight))
        weights[fname] = round(weight)
    return weights

# Rule1: acquisition mode
def acq_rule(fname: str) -> str:
    """Return the first ``@``-separated field of the file's base name.

    E.g. ``/x/a@b@c.h5`` -> ``a``; a base name without ``@`` is returned whole.
    """
    stem = fname.split("/")[-1]
    head, _, _ = stem.partition("@")
    return head

print("\nAcquisition groups and counts:")
acq_groups, acq_counts = get_rule_counts(h5files, acq_rule)
# Report how many files fall into each acquisition group.
for group_key, group_size in acq_counts.items():
    print(f"Acquisition group: {group_key}, Count: {group_size}")



# Rule2: which center the sample comes from
def center_rule(fname: str) -> str:
    """Return the fourth ``@``-separated field (index 3) of the file's base name.

    Returns the empty string when the base name has fewer than four fields.
    """
    fields = fname.split("/")[-1].split("@")
    if len(fields) < 4:
        return ""
    return fields[3]

print("\nCenter groups and counts:")
center_groups, center_counts = get_rule_counts(h5files, center_rule)
# Report how many files fall into each center group.
for center_key, center_size in center_counts.items():
    print(f"Center group: {center_key}, Count: {center_size}")

# Rule3: global fname
def global_rule(fname: str) -> str:
    """Group key: the base name without its last two ``@``-separated fields.

    A base name with two or fewer fields yields the empty string.
    NOTE(review): this produces exactly the same key as ``base_rule``.
    """
    basename = fname.rpartition("/")[2]
    pieces = basename.split("@")
    return "@".join(pieces[:-2])

global_groups, global_counts = get_rule_counts(h5files, global_rule)

# NOTE(review): global_rule produces the same keys as base_rule, so the base
# grouping's ratio is effectively applied twice below — confirm this is intended.
balance_rules = [
    (base_rule, basecounts),
    (acq_rule, acq_counts),
    (center_rule, center_counts),
    (global_rule, global_counts),
]
# Every file starts at weight 1.0 before balancing.
weights = balance_weights(
    dict.fromkeys(h5files, 1.0),
    balance_rules,
    max_weight=8,
    min_weight=1,
    momentum=momentum_func,
)

print("\nWeights after balancing:")
for path, balanced in weights.items():
    print(f"File: {path}, Weight: {balanced}")

def weight_sum_for_groups(weights: dict, rule: Callable) -> Tuple[dict, dict]:
    """Group files by ``rule`` and total their weights per group.

    Args:
        weights: Maps each file name to its numeric weight.
        rule: Maps a file name to its group key.

    Returns:
        ``(groups, group_weights)``: ``groups`` maps each key to the list of
        file names in that group (in ``weights`` order); ``group_weights`` maps
        each key to the sum of those files' weights.
    """
    groups = {}
    group_weights = {}
    for fname, weight in weights.items():
        groupname = rule(fname)
        groups.setdefault(groupname, []).append(fname)
        group_weights[groupname] = group_weights.get(groupname, 0) + weight
    return groups, group_weights

print("\nFinal group weights after balancing:")
print(f"totalweights: {sum(weights.values())}")
final_groups, finalweights = weight_sum_for_groups(weights, base_rule)
# Report the summed balanced weight of each base group.
for groupname, total in finalweights.items():
    print(f"Final group: {groupname}, Total weight: {total}")
