Add group loss balancing for calibration targets #30

@baogorek

Problem

When calibrating against target groups of very different sizes (e.g., mixing national-level totals with state-level age distributions), the loss function can become dominated by the high-cardinality groups. For example:

  • 3 national-level targets (population, income, employment)
  • 50 state-level age distributions with 18 bins each (900 targets total)

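Without balancing, and assuming every target contributes one equally weighted term, the 900 age distribution targets make up 900 of the 903 loss terms (about 99.7%). They would dominate the loss, causing poor calibration for the important national-level targets.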

Proposed Solution

Add support for group-wise loss averaging in the calibration module, where targets can be assigned to groups and each group contributes equally to the total loss, regardless of how many targets it contains.

This is particularly important for calibration tasks where:

  • Targets represent histograms or distributions (multiple bins for one conceptual target)
  • Targets have mixed granularity (national vs. state vs. county level)
  • Some targets are naturally grouped (e.g., all age bins for a given geography)
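
One way to write the proposed group-wise averaging, as a sketch rather than a settled definition: let \ell_i be whatever per-target loss the calibration module already uses (this issue does not specify it), let G be the number of groups, and let n_g be the number of targets in group g. All symbols here are notation introduced for illustration.

```latex
% Mean of the per-group mean losses: every group contributes 1/G of the total.
L_{\text{balanced}}
  = \frac{1}{G} \sum_{g=1}^{G} \frac{1}{n_g} \sum_{i \in g} \ell_i
  = \sum_{i} w_i \, \ell_i ,
\qquad
w_i = \frac{1}{G \, n_{g(i)}}

% For the example above (3 singleton groups + 50 groups of 18 bins, so G = 53):
%   national target:  w_i = 1/53
%   single age bin:   w_i = 1/(53 \cdot 18) = 1/954
```

With a single group containing all targets, or with every target in its own group, this reduces to the plain mean over targets.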

Key Difference from Issue #5

This issue (#30): Group loss balancing - ensuring groups of related targets (like histogram bins) contribute equally to the optimization loss, preventing high-cardinality groups from dominating.

Issue #5: Group sparsity - enforcing structured sparsity patterns where entire groups of parameters are zeroed out together (e.g., all weights connecting to a neuron).

Example Use Case

# National targets (3 singletons)
national_targets = [total_population, total_income, total_employment]

# State age distributions (18 bins × 50 states = 900 targets)  
state_age_targets = [age_0_5, age_5_10, ..., age_85_plus] × 50_states

# Without group balancing: 900 age targets dominate the 3 national targets
# With group balancing: each state's age distribution counts as 1 unit,
# so 3 national groups + 50 state groups (53 groups in total) contribute equally
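
The group assignment for this use case could be encoded as one integer group id per target. The following is a minimal sketch assuming NumPy and a flat target ordering (national targets first, then each state's 18 bins); the variable names and the integer-id encoding are assumptions for illustration, not an existing API.

```python
import numpy as np

N_NATIONAL = 3    # population, income, employment
N_STATES = 50
N_AGE_BINS = 18

# Each national target is its own group: ids 0, 1, 2.
national_groups = np.arange(N_NATIONAL)
# All 18 age bins of one state share a group id: ids 3 .. 52.
state_groups = N_NATIONAL + np.repeat(np.arange(N_STATES), N_AGE_BINS)
target_groups = np.concatenate([national_groups, state_groups])

# Weight 1/group_size per target, so each group's weights sum to 1.
group_sizes = np.bincount(target_groups)
per_target_weight = 1.0 / group_sizes[target_groups]

# Share of the total loss weight held by the 3 national targets.
national_share_unbalanced = N_NATIONAL / target_groups.size
national_share_balanced = (
    per_target_weight[:N_NATIONAL].sum() / per_target_weight.sum()
)
```

Under these assumptions the national targets go from 3/903 (about 0.3%) of the loss weight to 3/53 (about 5.7%).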

Implementation Requirements

  • Add target_groups parameter to specify which targets belong together
  • Implement weighted loss calculation where each target gets weight 1/group_size, so that each group's total weight is 1
  • Ensure backwards compatibility (no groups = original behavior)
  • Add tests for various grouping scenarios
  • Update verbose output to show group-aware metrics
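
The requirements above could look roughly like the following sketch. It assumes NumPy and a mean squared relative error as the per-target loss; `calibration_loss` and `per_group_metrics` are hypothetical names, and only the `target_groups` parameter comes from this issue.

```python
import numpy as np

def calibration_loss(estimates, targets, target_groups=None):
    """Mean squared relative error, optionally balanced across groups."""
    estimates = np.asarray(estimates, dtype=float)
    targets = np.asarray(targets, dtype=float)
    rel_err_sq = ((estimates - targets) / targets) ** 2
    if target_groups is None:
        # Backwards compatible: no groups means every target counts equally.
        return float(rel_err_sq.mean())
    _, inverse, counts = np.unique(
        target_groups, return_inverse=True, return_counts=True
    )
    # Mean loss within each group, then mean across groups.
    group_means = np.bincount(inverse, weights=rel_err_sq) / counts
    return float(group_means.mean())

def per_group_metrics(estimates, targets, target_groups):
    """Return {group label: mean absolute relative error} for verbose output."""
    estimates = np.asarray(estimates, dtype=float)
    targets = np.asarray(targets, dtype=float)
    abs_rel_err = np.abs((estimates - targets) / targets)
    labels, inverse, counts = np.unique(
        target_groups, return_inverse=True, return_counts=True
    )
    group_means = np.bincount(inverse, weights=abs_rel_err) / counts
    return dict(zip(labels.tolist(), group_means.tolist()))

# Toy data: one national target that is 10% off, three exact age bins.
est = [110.0, 100.0, 100.0, 100.0]
tgt = [100.0, 100.0, 100.0, 100.0]
groups = ["nat", "age", "age", "age"]

ungrouped = calibration_loss(est, tgt)
balanced = calibration_loss(est, tgt, groups)
singletons = calibration_loss(est, tgt, [0, 1, 2, 3])
report = per_group_metrics(est, tgt, groups)
```

In this toy case the balanced loss is twice the ungrouped one, because the national error is averaged over 2 groups instead of 4 targets. Putting every target in its own group reproduces the ungrouped loss, which is one candidate for a backwards-compatibility test.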
