In [None]:
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # reproducible logits

# Simulated logits: one row of unnormalized scores per batch element.
batch_size = 2
vocab_size = 10

logits_q = torch.randn(batch_size, vocab_size)  # "student"
logits_p = torch.randn(batch_size, vocab_size)  # "teacher"

# Normalize over the FULL vocabulary before any sharding. The per-shard
# results below only add up to the global KL because every shard sees
# log-probs that were normalized globally; taking log_softmax inside each
# shard would renormalize that shard on its own and break the identity.
log_q = F.log_softmax(logits_q, dim=-1)  # log Q
log_p = F.log_softmax(logits_p, dim=-1)  # log P

# Global KL using F.kl_div. With log_target=True, kl_div(input, target)
# sums exp(target) * (target - input); here input=log P and target=log Q,
# so this is KL(Q || P) = sum_v Q(v) * (log Q(v) - log P(v)), summed over
# the vocabulary and the batch, then divided by batch_size ('batchmean').
# NOTE(review): that is KL(student || teacher). If the intent is the usual
# distillation loss KL(teacher || student), swap the first two arguments
# here and in both shard calls below.
kl_global = F.kl_div(log_p, log_q, reduction='batchmean', log_target=True)

# Local shards: split the vocab into two contiguous chunks.
shard_1 = slice(0, vocab_size // 2)
shard_2 = slice(vocab_size // 2, vocab_size)

# Per-shard KL contributions. 'batchmean' sums each shard's terms over its
# vocab slice and the batch, then divides by batch_size, so each result is a
# 0-dim tensor whose batch normalization matches kl_global's.
kl_local_1 = F.kl_div(
    log_p[:, shard_1],
    log_q[:, shard_1],
    reduction='batchmean',
    log_target=True,
)  # 0-dim tensor
kl_local_2 = F.kl_div(
    log_p[:, shard_2],
    log_q[:, shard_2],
    reduction='batchmean',
    log_target=True,
)  # 0-dim tensor

# Combine local losses: KL is a sum over vocab entries and both calls divide
# by the same batch_size, so the shard results simply add.
kl_manual = kl_local_1 + kl_local_2

print(f"KL (shard 1)                 = {kl_local_1.item():.6f}")
print(f"KL (shard 2)                 = {kl_local_2.item():.6f}")
print(f"KL (global, F.kl_div)        = {kl_global.item():.6f}")
print(f"KL (sum of shards, manual)  = {kl_manual.item():.6f}")

# Verify programmatically rather than by eye: raises AssertionError with a
# diff on mismatch (default float32 tolerances absorb summation-order noise).
torch.testing.assert_close(kl_manual, kl_global)

tensor(0.6451)
KL (global, F.kl_div)        = 0.738795
KL (sum of shards, manual)  = 0.738795
