Skip to content
Permalink
Browse files

[tests/optims] Add distributed sgd and adamw test

  • Loading branch information...
iory committed Apr 12, 2019
1 parent 4e8ed9a commit 3f697ec23e74d196bfe5a0601d489124bad82240
No changes.
@@ -0,0 +1,45 @@
import os
import unittest

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn

from machina.optims import DistributedAdamW


def init_processes(rank, world_size,
function, backend='tcp'):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend, rank=rank,
world_size=world_size)
function(rank, world_size)


class TestDistributedAdamW(unittest.TestCase):

def test_step(self):

def _run(rank, world_size):
model = nn.Linear(10, 1)
optimizer = DistributedAdamW(
model.parameters())

optimizer.zero_grad()
loss = model(torch.ones(10).float())
loss.backward()
optimizer.step()

processes = []
world_size = 4
for rank in range(world_size):
p = Process(target=init_processes,
args=(rank,
world_size,
_run))
p.start()
processes.append(p)
for p in processes:
p.join()
@@ -0,0 +1,45 @@
import os
import unittest

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn

from machina.optims import DistributedSGD


def init_processes(rank, world_size,
function, backend='tcp'):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend, rank=rank,
world_size=world_size)
function(rank, world_size)


class TestDistributedSGD(unittest.TestCase):

def test_step(self):

def _run(rank, world_size):
model = nn.Linear(10, 1)
optimizer = DistributedSGD(
model.parameters())

optimizer.zero_grad()
loss = model(torch.ones(10).float())
loss.backward()
optimizer.step()

processes = []
world_size = 4
for rank in range(world_size):
p = Process(target=init_processes,
args=(rank,
world_size,
_run))
p.start()
processes.append(p)
for p in processes:
p.join()

0 comments on commit 3f697ec

Please sign in to comment.
You can’t perform that action at this time.