In [1]:
# allreduce_tensor.py
import torch, time
import hivemind

# 1) Connect to the DHT, which peers use for discovery and coordination.
#    - Very first peer: create the DHT without initial_peers, e.g.
#          dht = hivemind.DHT(start=True, client_mode=False)
#      and hand the addresses printed below to everyone else.
#    - Every other peer (this notebook): put an address printed by the
#      first peer into the bootstrap list.
bootstrap_peers = [
    '/ip4/127.0.0.1/tcp/41001/p2p/12D3KooWJbtD23NdFUF7wFCFx6Jz2QTW7C6jM9LmGFgpe4cW4s4Y',
]
dht = hivemind.DHT(initial_peers=bootstrap_peers, start=True)

# Addresses under which this peer is visible; further peers can bootstrap from them.
visible_maddrs = list(map(str, dht.get_visible_maddrs()))
print("Share this with other peers:", visible_maddrs)

Share this with other peers: ['/ip4/127.0.0.1/tcp/34181/p2p/12D3KooWN3ZMXJNk753iCxATNgQ9FRnCuoxpDBpwJ5okdp6J9d68']


In [2]:

# 2) Build the local tensor that will take part in the all-reduce.
#    One random integer k is drawn from [1, 10) and repeated: [k, k, k, k].
fill_value = torch.randint(1, 10, ()).item()
local = fill_value * torch.ones(4)
print("local before:", local.tolist())

# 3) Wrap the tensor in an averager. Every peer that should end up in the
#    same averaging group must be created with the SAME prefix.
averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[local],
    dht=dht,
    prefix="demo/allreduce",
    target_group_size=4,
    start=True,
)

local before: [8.0, 8.0, 8.0, 8.0]


Agg A modes (<AveragingMode.NODE: 0>, <AveragingMode.NODE: 0>)
agg B <function load_balance_peers at 0x70ae0475ca60> 4 [None, None] 0
modes A None
modes B (2, 2)
modes C (<AveragingMode.NODE: 0>, <AveragingMode.NODE: 0>)


In [3]:
# Simulate a local update: add a random integer from [1, 10) to every entry.
# The update is done in place on the tensor that was handed to the averager.
# NOTE: this cell is not idempotent -- every re-run shifts `local` again.
local += torch.ones(4) * (torch.randint(1, 10, ()).item())
print("local before:", local.tolist())

# 4) Run one all-reduce round (blocks until a group forms or the timeout expires).
#    - On success the averaged tensors are updated in place and step() returns the
#      group info: a mapping from peer ID to the `gather` payload each peer sent.
#    - Per the hivemind docs, averaging_alpha is an option of the
#      DecentralizedAverager constructor (not of step()); its documented default
#      already overwrites the local tensors with the group average.
#    - Pass weight=<float> to step() to take part in a weighted average.
info = averager.step(timeout=30.0, gather={'step': 21})
print("group info:", info)

# step() is documented to return None when the round fails; in that case `local`
# must not be presented as an averaged result.
if info is None:
    print("averaging round did not succeed (step() returned None); "
          "`local` below is not a group average")
print("local after:", local.tolist())

local before: [14.0, 14.0, 14.0, 14.0]
group info: {<libp2p.peer.id.ID (12D3KooWRGwyJCxXnqWxWRpLzX15T8qXiYknBpWBUM7g4iVa6woR)>: {'step': 20}, <libp2p.peer.id.ID (12D3KooWN3ZMXJNk753iCxATNgQ9FRnCuoxpDBpwJ5okdp6J9d68)>: {'step': 21}}
local after: [11.5, 11.5, 11.5, 11.5]


In [4]:

# Tear down in dependency order: the averager (which uses the DHT) first,
# then the DHT itself.
for component in (averager, dht):
    component.shutdown()
