In [1]:
import hivemind
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.factorized_averager import FactorizedGradientAverager
from hivemind.optim.experimental.power_sgd_averager import PowerSGDAverager

In [2]:
# Create and start the root DHT node. Every Peer defined later in this notebook
# bootstraps its own DHT from this node's visible multiaddresses.
dht_root = hivemind.DHT(start=True)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST

In [4]:
# Single-step preprocessing pipeline: convert each image to a tensor.
transform = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor()],
)

# MNIST dataset rooted at the current directory (downloaded if needed).
train_data = MNIST(root=".", transform=transform, download=True)

In [5]:
class SmallCNN(nn.Module):
    """Compact CNN classifier: four 5x5 conv + ReLU stages, a 2x2 max-pool and a two-layer MLP head.

    The head expects ``64 * 6 * 6`` flattened features, which is what the
    feature extractor produces for 1x28x28 inputs (28 -> 24 -> 20 -> 16 -> 12
    through the four unpadded 5x5 convolutions, then 6 after pooling).
    """

    def __init__(self):
        super().__init__()

        # Channel widths at the input and after each successive convolution.
        widths = (1, 4, 16, 64, 64)
        stages = []
        for in_channels, out_channels in zip(widths, widths[1:]):
            stages += [nn.Conv2d(in_channels, out_channels, kernel_size=5), nn.ReLU()]
        stages.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*stages)

        hidden_units = 400
        self.cls = nn.Sequential(
            nn.Linear(64 * 6 * 6, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and return the 10 class logits."""
        batch = x.size(0)
        flat = self.features(x).view(batch, -1)
        return self.cls(flat)



In [6]:
import threading
import os

class Peer(threading.Thread):
    """One simulated training participant, run as a daemon thread.

    Each peer owns a ``hivemind.DHT`` node bootstrapped from the module-level
    ``dht_root``, its own ``SmallCNN`` replica, and a gradient averager built
    over that replica's parameters. ``run`` trains the replica on the
    module-level ``train_data`` with plain SGD, calling the averager's
    ``step()`` after every backward pass and applying the optimizer update
    inside ``use_averaged_gradients()``.

    Args:
        start: if True, start the training thread at the end of construction.
        batch_size: number of samples per DataLoader batch.
        lr: learning rate passed to ``torch.optim.SGD``.
        max_steps: upper bound on the number of training steps; training also
            ends earlier if the DataLoader is exhausted first.

    Raises:
        ValueError: if ``batch_size`` or ``max_steps`` is smaller than 1.
    """

    def __init__(self, *, start: bool, batch_size: int = 32, lr: float = 0.01, max_steps: int = 502):
        # Validate first so that a bad argument never leaves a half-built peer
        # with a running DHT behind.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")

        super().__init__(daemon=True)
        self.batch_size = batch_size
        self.lr = lr
        self.max_steps = max_steps

        self.dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        self.model = SmallCNN()
        # Pre-allocate zeroed gradient buffers in shared memory before the
        # averager is constructed over the parameters.
        for param in self.model.parameters():
            param.grad = torch.zeros_like(param).share_memory_()

        # Alternative averagers kept for side-by-side experiments:
#         self.averager = GradientAverager(
#             self.model.parameters(), dht=self.dht, target_group_size=2, prefix='my_mega_exp', start=True,
#         )
        # NOTE(review): the positional ``2`` is presumably the factorization
        # rank -- confirm against the averager's signature.
        self.averager = FactorizedGradientAverager(
            self.model.parameters(), 2, dht=self.dht, target_group_size=2, prefix='my_mega_exp', start=True,
        )
#         self.averager = PowerSGDAverager(
#             self.model.parameters(), 2, dht=self.dht, target_group_size=2, prefix='my_mega_exp', start=True,
#         )
        if start:
            self.start()

    def run(self):
        """Thread body: train the local replica, invoking the averager after every batch.

        Prints the peer id once at start-up and, every 10th step, the step
        index, the last three characters of the peer id and the current loss.
        """
        print('started', self.dht.peer_id)
        train_dataloader = torch.utils.data.DataLoader(
            train_data, num_workers=0, batch_size=self.batch_size, shuffle=True
        )
        opt = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        for i, (xb, yb) in enumerate(train_dataloader):
            logits = self.model(xb)
            loss = F.cross_entropy(logits, yb)

            opt.zero_grad()
            loss.backward()

            # Report the size of the batch that was actually processed rather
            # than the configured one: the loader does not drop a short final
            # batch, so the two can differ.
            self.averager.accumulate_grads_(batch_size=len(xb))

            self.averager.step()
            with self.averager.use_averaged_gradients():
                opt.step()
            self.averager.reset_accumulated_grads_()

            if i % 10 == 0:
                print(i, self.dht.peer_id.pretty()[-3:], loss.item())
            # ``i`` is zero-based, so ``i + 1`` steps have been completed here.
            if i + 1 >= self.max_steps:
                break
        
        



In [None]:
# Build both peers without starting them so their weights can be aligned first.
peers = [Peer(start=False) for _ in range(2)]

# Copy the first peer's weights into the second so both replicas start identical.
reference_weights = peers[0].model.state_dict()
peers[1].model.load_state_dict(reference_weights)

# Launch every training thread, then block until all of them have finished.
for peer in peers:
    peer.start()
for peer in peers:
    peer.join()

started QmVH7UogsEXqFT9jhdDf2o34LkCVxzt176DUER2iHF98Fg
started QmVbZ5zzp3g7TbzbY6BuAt16RH6oWZXoshSH58SCGvyo77
0 o77 2.3087573051452637
0 8Fg 2.302330732345581
10 8Fg 2.3007829189300537
10 o77 2.300511360168457
20 8Fg 2.292266368865967
20 o77 2.310678482055664
30 8Fg 2.2952449321746826
30 o77 2.294466257095337
40 o77 2.305081844329834
40 8Fg 2.2960262298583984
50 8Fg 2.301462411880493
50 o77 2.300893783569336
60 o77 2.305816173553467
60 8Fg 2.3013172149658203
70 o77 2.30063796043396
70 8Fg 2.3092100620269775
80 o77 2.3058948516845703
80 8Fg 2.3066720962524414
90 o77 2.3082518577575684
90 8Fg 2.302436113357544
100 8Fg 2.2975590229034424
100 o77 2.299251079559326
110 o77 2.303760051727295
110 8Fg 2.298640727996826
120 o77 2.3013532161712646
120 8Fg 2.2970705032348633
