In [None]:
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
from torchvision import datasets, transforms
from train import *
from dataload import *
from model import *

In [None]:
def distributed_is_initialized():
    """Report whether a default ``torch.distributed`` process group is active.

    ``dist.is_available()`` is checked first and short-circuits, so
    ``dist.is_initialized()`` is only consulted when the distributed package
    is available in this build.

    Returns:
        bool: True when distributed support is available AND a default process
        group has been initialized; False otherwise.
    """
    return bool(dist.is_available() and dist.is_initialized())

In [None]:
def run(args):
    """Build the model, optimizer and MNIST loaders, then train with ``Trainer``.

    The model is wrapped in ``DistributedDataParallel`` only when a default
    process group is already active (per ``distributed_is_initialized``);
    otherwise it is used as a plain single-process module.

    Args:
        args: dict read for the keys ``'lr'``, ``'root'``, ``'batch_size'`` and
            ``'epochs'``; any other entries are not used here.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    net = Net()
    distributed = distributed_is_initialized()
    print("is_distributed:", distributed)

    # nn.Module.to moves parameters in place, so both paths share this call
    # and its return value is not needed.
    net.to(device)
    if distributed:
        net = torch.nn.parallel.DistributedDataParallel(net)

    optimizer = torch.optim.Adam(net.parameters(), lr=args['lr'])

    data_root = args['root']
    batch = args['batch_size']
    train_loader = MNISTDataLoader(data_root, batch, train=True, distributed=distributed)
    test_loader = MNISTDataLoader(data_root, batch, train=False, distributed=distributed)

    Trainer(net, optimizer, train_loader, test_loader, device).fit(args['epochs'])

In [None]:
def main(overrides=None):
    """Build the run configuration, join the process group if needed, and train.

    When ``world_size > 1`` this joins a ``torch.distributed`` process group
    before calling ``run``. A group created here is destroyed again afterwards,
    so the cell can be re-run in the same kernel without the "default process
    group initialized twice" error.

    Args:
        overrides: Optional dict whose entries replace the defaults below, e.g.
            ``main({'rank': 1})`` to start the second worker of the group.
            ``None`` (the default) keeps the original settings unchanged.
    """
    argv = {'world_size': 2,
            'rank': 0,
            'epochs': 10,
            'back_end': 'nccl',
            'init_method': 'tcp://10.1.1.101:23456',
            'lr': 1e-3,
            'root': 'data',
            'batch_size': 32,
            }
    if overrides:
        # Update the local dict only; the caller's mapping is never mutated.
        argv.update(overrides)

    print(argv)

    created_group = False
    if argv['world_size'] > 1 and not distributed_is_initialized():
        backend = argv['back_end']
        # NCCL requires CUDA; on a CPU-only host fall back to Gloo instead of
        # failing inside init_process_group.
        if backend == 'nccl' and not torch.cuda.is_available():
            print("CUDA unavailable; falling back from 'nccl' to 'gloo' backend")
            backend = 'gloo'
        dist.init_process_group(
            backend=backend,
            init_method=argv['init_method'],
            world_size=argv['world_size'],
            rank=argv['rank'],
        )
        created_group = True

    print('Start Run')
    try:
        run(argv)
    finally:
        # Only tear down a group this call created; a pre-existing group is
        # left alone for whoever set it up.
        if created_group:
            dist.destroy_process_group()

In [None]:
# Launch this process with main()'s built-in settings (world_size 2, rank 0).
# The remaining rank of the group must be started separately with its own
# 'rank' value for the process group to form.
main()
