In [22]:
from typing import List, Dict, Sequence, Iterable, Tuple
import random
import torch


# ---------------------------
# Utilities and containers
# ---------------------------

class PackedBatch:
    """Collect variable-length samples and pad them into batch tensors.

    Callers append one 1-D tensor per sample to each of the three lists
    below; ``to_dict`` then right-pads every sample to the longest
    ``input_ids`` length held by this container.
    """

    def __init__(self):
        # Parallel per-sample lists: entry i of each list belongs to sample i.
        self.input_ids: List[torch.Tensor] = []
        self.attention_mask: List[torch.Tensor] = []
        self.labels: List[torch.Tensor] = []

    def max_length(self) -> int:
        """Return the longest ``input_ids`` length, or 0 when no samples are held."""
        return max((t.size(0) for t in self.input_ids), default=0)

    def to_dict(self) -> Dict[str, torch.Tensor]:
        """Pad the stored samples into a fixed-size batch dictionary.

        Returns:
            A dict with keys ``"input_ids"``, ``"attention_mask"`` and
            ``"labels"``. Each value is a ``torch.long`` tensor of shape
            ``(num_samples, max_length)``. Padding positions hold 0 for
            ``input_ids`` and ``attention_mask`` and -100 for ``labels``.
            With no samples the tensors have shape ``(0, 0)``, so they stay
            rank-2 like every other batch and row-wise reductions such as
            ``sum(dim=1)`` keep working.

        Note:
            Each sample's mask and labels are written into a slice whose
            width is the length of that sample's ``input_ids``. The three
            lists are walked in lockstep with ``zip``, so iteration stops at
            the shortest list.
        """
        # No special case for the empty container: batch_size == 0 and
        # max_len == 0 naturally produce (0, 0) tensors and an empty loop.
        shape = (len(self.input_ids), self.max_length())

        input_ids_tensor = torch.zeros(shape, dtype=torch.long)
        attention_mask_tensor = torch.zeros(shape, dtype=torch.long)
        # -100 marks padded label positions.
        labels_tensor = torch.full(shape, -100, dtype=torch.long)

        samples = zip(self.input_ids, self.attention_mask, self.labels)
        for row, (ids, mask, lbls) in enumerate(samples):
            width = ids.size(0)
            input_ids_tensor[row, :width] = ids
            attention_mask_tensor[row, :width] = mask
            labels_tensor[row, :width] = lbls

        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "labels": labels_tensor,
        }


def summarize_batches(batches: Sequence[Dict[str, torch.Tensor]], label: str) -> Tuple[int, int]:
    """Count padded and supervised tokens across a collection of batches.

    Args:
        batches: Batch dicts; each must provide ``"input_ids"`` and ``"labels"``.
        label: Tag for the printed summary line. A falsy value (e.g. ``""``)
            suppresses printing.

    Returns:
        ``(total_tokens, useful_tokens)``: the number of elements in all
        ``input_ids`` tensors (padding included) and the number of label
        entries that differ from -100.
    """
    # Two separate passes over ``batches``: first sizes, then label counts.
    padded_sizes = [batch["input_ids"].numel() for batch in batches]
    supervised_counts = [batch["labels"].ne(-100).sum().item() for batch in batches]

    total_tokens = sum(padded_sizes)
    useful_tokens = sum(supervised_counts)

    if label:
        # Guard against division by zero when there are no tokens at all.
        efficiency = useful_tokens / total_tokens if total_tokens else 0.0
        print(
            f"[opensloth.ddp_patch] {label}: total_tokens={total_tokens} "
            f"useful_tokens={useful_tokens} efficiency={efficiency:.2%}"
        )

    return total_tokens, useful_tokens


# ---------------------------
# Sample extraction
# ---------------------------

def extract_samples(batches: Iterable[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
    """
    Break padded batches apart into individual samples, longest first.

    A row's true length is the sum of its ``attention_mask`` entries, and the
    sample keeps the first ``length`` positions of that row. This assumes
    padding sits at the end of each row (the mask is not checked for that).

    Returns a list of dicts with keys ``"input_ids"``, ``"attention_mask"``,
    ``"labels"`` (row slices of the batch tensors) and ``"length"`` (int).
    The list is ordered by ``"length"`` descending; samples of equal length
    keep their original relative order.
    """
    collected: List[Dict[str, torch.Tensor]] = []

    for batch in batches:
        ids = batch["input_ids"]
        mask = batch["attention_mask"]
        targets = batch["labels"]

        # One unpadded length per row of the batch.
        row_lengths = [int(count) for count in mask.sum(dim=1).tolist()]

        collected.extend(
            {
                "input_ids": ids[row, :size],
                "attention_mask": mask[row, :size],
                "labels": targets[row, :size],
                "length": size,
            }
            for row, size in enumerate(row_lengths)
        )

    # sorted() is stable, so ties stay in extraction order.
    return sorted(collected, key=lambda sample: sample["length"], reverse=True)


# ---------------------------
# Optimal repacking (DP, contiguous groups)
# ---------------------------

def _optimal_partition_by_length(lengths: List[int], k: int) -> List[Tuple[int, int]]:
    """
    Split a descending-sorted length list into k contiguous, non-empty groups.

    The split minimizes the padded size, i.e. the sum over groups of
    (max_length_in_group * group_size). Because ``lengths`` is expected to be
    sorted longest-first, the maximum of a group is simply its first element.
    The ordering is not verified: an unsorted input still yields a valid
    partition, but its cost is computed from each group's first element.

    Args:
        lengths: Sample lengths, sorted in descending order.
        k: Number of groups to produce (must satisfy 1 <= k <= len(lengths)).

    Returns:
        A list of k ``(start_idx, end_idx)`` pairs, inclusive, in ascending
        order, covering indices 0..n-1. When several splits share the minimal
        cost, the earliest split point wins at each step.

    Raises:
        ValueError: if ``k`` is not positive or there are fewer samples than
            groups.

    Complexity: O(k * n^2) time and O(k * n) memory for n samples.
    """
    n = len(lengths)
    if k <= 0:
        raise ValueError("k must be >= 1")
    if n < k:
        # k non-empty groups need at least k samples.
        raise ValueError(f"Not enough samples ({n}) to fill {k} non-empty batches")

    # A real infinity (not a large integer) so that no achievable cost can
    # ever compare as "not better than unreachable".
    unreachable = float("inf")

    # cost[g][j]: minimal cost of splitting 0..j into g + 1 groups.
    # split[g][j]: index s where the previous groups end; the last group is
    #              s+1..j (s == -1 means the last group starts at index 0).
    cost = [[unreachable] * n for _ in range(k)]
    split = [[-1] * n for _ in range(k)]

    # One group covering 0..j is padded to lengths[0].
    for j in range(n):
        cost[0][j] = lengths[0] * (j + 1)

    for g in range(1, k):
        prev_row = cost[g - 1]
        # g + 1 non-empty groups need at least g + 1 items, hence j >= g.
        for j in range(g, n):
            best_cost = unreachable
            best_s = -1
            # The first g groups occupy 0..s and must be non-empty, so
            # s >= g - 1; the last group s+1..j must be non-empty, so s < j.
            for s in range(g - 1, j):
                candidate = prev_row[s] + lengths[s + 1] * (j - s)
                if candidate < best_cost:
                    best_cost = candidate
                    best_s = s
            cost[g][j] = best_cost
            split[g][j] = best_s

    # Walk the split table backwards from the full range to recover bounds.
    bounds: List[Tuple[int, int]] = []
    end = n - 1
    for g in range(k - 1, -1, -1):
        s = split[g][end]
        bounds.append((s + 1, end))
        end = s
    bounds.reverse()
    return bounds


def repack_batches(
    batches: Sequence[Dict[str, torch.Tensor]],
    verbose: bool = False
) -> List[Dict[str, torch.Tensor]]:
    """
    Repack samples into the same number of batches with minimal padding.

    Steps:
    1) Extract all samples and sort them by length (descending).
    2) Use DP to partition them into K contiguous groups minimizing
       sum(max_len * group_size), where K is the number of input batches.
    3) Pad each group into a new batch dict via PackedBatch.

    Args:
        batches: Padded batch dicts with ``"input_ids"``, ``"attention_mask"``
            and ``"labels"`` entries.
        verbose: When true, print the chosen group boundaries.

    Returns:
        A list of K padded batch dicts, ordered from the longest group to the
        shortest. Every output batch holds at least one sample. An empty
        input yields an empty list.

    Raises:
        ValueError: if the input contains fewer samples than batches.
    """
    K = len(batches)
    if K == 0:
        # Nothing to repack; the partitioner requires at least one group.
        return []

    samples = extract_samples(batches)
    lengths = [s["length"] for s in samples]
    bounds = _optimal_partition_by_length(lengths, K)

    if verbose:
        print("[opensloth.ddp_patch] DP group boundaries (start,end,len,max):")

    packed_batches: List[PackedBatch] = []
    for bi, (lo, hi) in enumerate(bounds):
        if verbose:
            # Samples are sorted descending, so the group's max sits at 'lo'.
            group_max = samples[lo]["length"]
            print(f"  Batch[{bi}] = [{lo}:{hi}] "
                  f"(size={hi-lo+1}, max={group_max})")

        target = PackedBatch()
        for s in samples[lo:hi + 1]:
            target.input_ids.append(s["input_ids"])
            target.attention_mask.append(s["attention_mask"])
            target.labels.append(s["labels"])
        packed_batches.append(target)

    # Convert to padded tensors
    return [pb.to_dict() for pb in packed_batches]


# ---------------------------
# Mock data for quick testing
# ---------------------------

def make_mock_batches(
    num_batches: int,
    batch_size: int,
    max_len: int,
    seed: int = 42
) -> List[Dict[str, torch.Tensor]]:
    """Build deterministic right-padded toy batches for quick experiments.

    For every batch, ``batch_size`` sample lengths are drawn uniformly from
    ``[5, max_len]`` using a ``random.Random(seed)`` generator. Valid
    positions of ``input_ids`` and ``labels`` hold ``position % 1000``;
    padding is 0 for ``input_ids``/``attention_mask`` and -100 for ``labels``.

    Args:
        num_batches: Number of batch dicts to create.
        batch_size: Rows per batch.
        max_len: Largest sample length that may be drawn.
        seed: Seed for the length generator.

    Returns:
        A list of dicts with ``torch.long`` tensors ``"input_ids"``,
        ``"attention_mask"`` and ``"labels"``, each of shape
        ``(batch_size, longest_length_in_that_batch)``.
    """
    generator = random.Random(seed)
    mock_batches = []

    for _ in range(num_batches):
        sample_lengths = [generator.randint(5, max_len) for _ in range(batch_size)]
        width = max(sample_lengths)

        positions = torch.arange(width)
        # valid[r, c] is True where column c lies inside row r's sample.
        valid = positions.unsqueeze(0) < torch.tensor(sample_lengths).unsqueeze(1)

        attention_mask = valid.to(torch.long)
        # Multiplying by the mask zeroes the padded columns.
        input_ids = (positions % 1000).unsqueeze(0) * attention_mask
        labels = input_ids.masked_fill(~valid, -100)

        mock_batches.append(
            {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        )

    return mock_batches


# ---------------------------
# Demo
# ---------------------------

if __name__ == "__main__":
    # Demo: build mock batches, report their padding efficiency, repack them,
    # and report again.
    # 16 mock batches of 4 samples each, lengths up to 500, default seed.
    orig_batches = make_mock_batches(num_batches=16, batch_size=4, max_len=500)
    summarize_batches(orig_batches, "Before repack")

    # DP repack into the same number of batches (none empty); verbose=True
    # prints the chosen group boundaries.
    new_batches = repack_batches(orig_batches, verbose=True)
    summarize_batches(new_batches, "After repack")


[opensloth.ddp_patch] Before repack: total_tokens=22180 useful_tokens=14099 efficiency=63.57%
[opensloth.ddp_patch] DP group boundaries (start,end,len,max):
  Batch[0] = [0:0] (size=1, max=495)
  Batch[1] = [1:3] (size=3, max=461)
  Batch[2] = [4:8] (size=5, max=419)
  Batch[3] = [9:13] (size=5, max=384)
  Batch[4] = [14:16] (size=3, max=364)
  Batch[5] = [17:18] (size=2, max=337)
  Batch[6] = [19:23] (size=5, max=314)
  Batch[7] = [24:26] (size=3, max=284)
  Batch[8] = [27:28] (size=2, max=240)
  Batch[9] = [29:32] (size=4, max=221)
  Batch[10] = [33:36] (size=4, max=188)
  Batch[11] = [37:42] (size=6, max=147)
  Batch[12] = [43:47] (size=5, max=119)
  Batch[13] = [48:51] (size=4, max=86)
  Batch[14] = [52:57] (size=6, max=57)
  Batch[15] = [58:63] (size=6, max=27)
[opensloth.ddp_patch] After repack: total_tokens=14522 useful_tokens=14099 efficiency=97.09%


tensor([ 916,  208, 2255, 2008, 1831, 1146,  843, 4467,  716, 4837, 3457,  264,
         248,  771, 1794, 1908, 4139, 4931,  221, 4597, 1631, 4464, 3437, 1808,
        3680, 4827, 2280,   57, 1310, 3463, 2789, 2278, 1276, 1766, 2759,  841,
         763, 3113,  796, 2942, 2819, 4945, 2168,  359, 3764, 4392, 1025, 3101,
         649, 4522])