In [1]:
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
from torchvision import datasets, transforms
from train import *
from dataload import *
from model import *

In [2]:
def distributed_is_initialized():
    """Report whether torch.distributed is usable and already initialized.

    Returns:
        bool: True only when both ``dist.is_available()`` and
        ``dist.is_initialized()`` report True; False in every other case.
    """
    # `and` short-circuits, so is_initialized() is only consulted when the
    # distributed package reports itself as available.
    return bool(dist.is_available() and dist.is_initialized())

In [3]:
def run(args):
    """Build the model, optimizer and data loaders, then train and evaluate.

    Args:
        args: Mapping read for the keys ``'lr'``, ``'root'``, ``'batch_size'``
            and ``'epochs'``.

    The model is placed on CUDA when ``torch.cuda.is_available()`` is true and
    on the CPU otherwise. If a process group is already initialized, the model
    is wrapped in ``DistributedDataParallel`` and the loaders are created with
    ``distributed=True``.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(device)

    net = Net()
    is_distributed = distributed_is_initialized()
    print("is_distributed:", is_distributed)

    # The model goes to the target device in both modes; only the
    # distributed case adds the DDP wrapper on top of it.
    net.to(device)
    if is_distributed:
        net = torch.nn.parallel.DistributedDataParallel(net)

    optimizer = torch.optim.Adam(net.parameters(), lr=args['lr'])

    def make_loader(is_train):
        # Train and test loaders differ only in the `train` flag.
        return MNISTDataLoader(
            args['root'],
            args['batch_size'],
            train=is_train,
            distributed=is_distributed,
        )

    train_loader = make_loader(True)
    test_loader = make_loader(False)

    Trainer(net, optimizer, train_loader, test_loader, device).fit(args['epochs'])

In [4]:
def main():
    """Set up this process's configuration, join the process group, and train.

    The settings are hard-coded in the ``argv`` dict below. When
    ``world_size`` is greater than 1, the default ``torch.distributed``
    process group is initialized before calling ``run`` and is destroyed
    again afterwards, even if ``run`` raises. Without that teardown the
    group stays alive in the notebook kernel, and calling ``main()`` a second
    time fails because the default process group is already initialized.

    Note: ``rank`` is fixed to 0 here, so every other participant has to run
    with its own rank value edited in.
    """
    argv = {
        'world_size': 2,
        'rank': 0,
        'epochs': 10,
        # NCCL communicates CUDA tensors; use 'gloo' on CPU-only hosts.
        'back_end': 'nccl',
        'init_method': 'tcp://10.1.1.101:23456',
        'lr': 1e-3,
        'root': 'data',
        'batch_size': 32,
    }

    print(argv)

    # Track whether *this call* created the group so we only tear down what
    # we own and never destroy a group set up by someone else.
    started_group = False
    if argv['world_size'] > 1:
        dist.init_process_group(
            backend=argv['back_end'],
            init_method=argv['init_method'],
            world_size=argv['world_size'],
            rank=argv['rank'],
        )
        started_group = True

    try:
        print('Start Run')
        run(argv)
    finally:
        if started_group:
            # Release the communicator so the cell can be re-run in the same kernel.
            dist.destroy_process_group()

In [5]:
# Start training for this process using the settings hard-coded in main().
main()


{'world_size': 2, 'rank': 0, 'epochs': 10, 'back_end': 'nccl', 'init_method': 'tcp://10.1.1.101:23456', 'lr': 0.001, 'root': 'data', 'batch_size': 32}
Start Run
cuda
is_distributed: True
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Processing...
Done!







Epoch: 1/10, train loss: 0.655683, train acc: 75.53%, test loss: 0.471824, test acc: 82.65%.
Epoch: 2/10, train loss: 0.414107, train acc: 84.78%, test loss: 0.397921, test acc: 85.51%.
Epoch: 3/10, train loss: 0.354224, train acc: 86.89%, test loss: 0.365463, test acc: 86.78%.
Epoch: 4/10, train loss: 0.323007, train acc: 87.94%, test loss: 0.348732, test acc: 87.32%.
Epoch: 5/10, train loss: 0.301348, train acc: 88.89%, test loss: 0.336323, test acc: 87.71%.
Epoch: 6/10, train loss: 0.285019, train acc: 89.39%, test loss: 0.333583, test acc: 87.91%.
Epoch: 7/10, train loss: 0.270017, train acc: 89.93%, test loss: 0.321420, test acc: 88.59%.
Epoch: 8/10, train loss: 0.256941, train acc: 90.43%, test loss: 0.312369, test acc: 88.88%.
Epoch: 9/10, train loss: 0.246305, train acc: 90.78%, test loss: 0.309474, test acc: 89.20%.
Epoch: 10/10, train loss: 0.235589, train acc: 91.23%, test loss: 0.306895, test acc: 89.54%.
