In [1]:
import torch.distributed.rpc as rpc
import torch
import os
import torch.optim as optim
import torch
from torch.multiprocessing import Process
from torchvision import datasets, transforms
from train import *
from dataload import *
from model import *


In [2]:
def run(args):
    """Build the model, distributed optimizer and data loaders, then train.

    Must be called after ``rpc.init_rpc`` has completed on this process,
    because the optimizer is built from the model's parameter RRefs.

    Args:
        args: Mapping providing the keys ``'host'``, ``'worker'``, ``'lr'``,
            ``'root'``, ``'batch_size'`` and ``'epochs'``.
    """
    # Bind the optimizer class explicitly instead of relying on one of the
    # notebook's ``from ... import *`` lines happening to re-export it.
    from torch.distributed.optim import DistributedOptimizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    # NOTE(review): CNNModel is given both RPC worker names; presumably it
    # places part of the network on args['worker'] -- confirm in model.py.
    model = CNNModel(args['host'], args['worker'], device)

    # setup distributed optimizer
    opt = DistributedOptimizer(
        optim.Adam,
        model.parameter_rrefs(),
        lr=args['lr'],
    )

    train_loader = MNISTDataLoader(args['root'], args['batch_size'], train=True)
    test_loader = MNISTDataLoader(args['root'], args['batch_size'], train=False)

    trainer = Trainer(model, opt, train_loader, test_loader, device)
    trainer.fit(args['epochs'])

In [3]:
def main():
    """Join the RPC group as rank 0 ("worker0"), run training, then shut down.

    The configuration is hard-coded below. ``rpc.shutdown()`` is always
    called once ``init_rpc`` has succeeded, even if training raises, so the
    RPC agent is not left initialized in the notebook kernel.
    """
    argv = {
        'world_size': 2,
        'rank': 0,
        'host': "worker0",
        'worker': "worker1",
        'epochs': 5,
        'lr': 1e-3,
        'root': 'data',
        'batch_size': 32,
    }

    print(argv)
    rpc.init_rpc(argv['host'], rank=argv['rank'], world_size=argv['world_size'])
    try:
        print('Start Run', argv['rank'])
        run(argv)
    finally:
        # Runs on both success and failure; without it a failed run leaves
        # the RPC agent alive and a later init_rpc in this kernel would fail.
        rpc.shutdown()

In [None]:
# Rendezvous endpoint for the RPC group; set in the process environment
# before main() calls rpc.init_rpc.
os.environ.update({
    'MASTER_ADDR': '10.1.1.101',
    'MASTER_PORT': '29505',
})
main()

{'world_size': 2, 'rank': 0, 'host': 'worker0', 'worker': 'worker1', 'epochs': 10, 'lr': 0.001, 'root': 'data', 'batch_size': 32}
Start Run 0
cuda
ConvNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
)
FCNet(
  (fc1): Linear(in_features=192, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (out): Linear(in_features=60, out_features=10, bias=True)
)
CNN model constructed: owner
Epoch: 1/10, train loss: 0.561120, train acc: 78.93%,
Epoch: 2/10, train loss: 0.372587, train acc: 86.53%,
Epoch: 3/10, train loss: 0.325346, train acc: 87.97%,
Epoch: 4/10, train loss: 0.298535, train acc: 88.95%,
Epoch: 5/10, train loss: 0.279765, train acc: 89.67%,
Epoch: 6/10, train loss: 0.262671, train acc: 90.21%,
Epoch: 7/10, train loss: 0.250404, train acc: 90.72%,
Epoch: 8/10, train loss: 0.239128, train acc: 91.16%,
Epoch: 9/10, train loss: 0.227272, train acc: 91.44%,
