Skip to content

Commit

Permalink
[tests/optims] Add distributed sgd and adamw test
Browse files Browse the repository at this point in the history
  • Loading branch information
iory committed Apr 12, 2019
1 parent 4e8ed9a commit 3f697ec
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
Empty file added tests/optims/__init__.py
Empty file.
45 changes: 45 additions & 0 deletions tests/optims/distributed_adamw_test.py
@@ -0,0 +1,45 @@
import os
import unittest

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn

from machina.optims import DistributedAdamW


def init_processes(rank, world_size,
                   function, backend='gloo'):
    """Initialize the default process group for this worker, then run *function*.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        function: Callable invoked as ``function(rank, world_size)`` once
            the process group is ready.
        backend: torch.distributed backend name. Defaults to ``'gloo'``;
            the original ``'tcp'`` backend was removed from
            torch.distributed in PyTorch 1.0 and would make
            init_process_group raise.
    """
    # All workers rendezvous on the same local address/port.
    # NOTE(review): a fixed port means concurrently-running test files
    # will collide — confirm tests are not run in parallel.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank,
                            world_size=world_size)
    function(rank, world_size)


class TestDistributedAdamW(unittest.TestCase):
    """Smoke test for DistributedAdamW across multiple worker processes."""

    def test_step(self):
        """Spawn workers and run one optimizer step in each.

        Each worker builds a tiny linear model, backpropagates a dummy
        loss, and calls ``optimizer.step()``. The parent process fails
        the test if any worker exits with a non-zero status.
        """

        def _run(rank, world_size):
            model = nn.Linear(10, 1)
            optimizer = DistributedAdamW(
                model.parameters())

            optimizer.zero_grad()
            # Raw model output has shape (1,), which backward() accepts
            # as an implicit scalar; it stands in for a real loss here.
            loss = model(torch.ones(10).float())
            loss.backward()
            optimizer.step()

        processes = []
        world_size = 4
        for rank in range(world_size):
            p = Process(target=init_processes,
                        args=(rank,
                              world_size,
                              _run))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
            # A crashed worker must fail the test rather than be
            # silently ignored by a bare join().
            self.assertEqual(p.exitcode, 0)
45 changes: 45 additions & 0 deletions tests/optims/distributed_sgd_test.py
@@ -0,0 +1,45 @@
import os
import unittest

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn

from machina.optims import DistributedSGD


def init_processes(rank, world_size,
                   function, backend='gloo'):
    """Initialize the default process group for this worker, then run *function*.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        function: Callable invoked as ``function(rank, world_size)`` once
            the process group is ready.
        backend: torch.distributed backend name. Defaults to ``'gloo'``;
            the original ``'tcp'`` backend was removed from
            torch.distributed in PyTorch 1.0 and would make
            init_process_group raise.
    """
    # All workers rendezvous on the same local address/port.
    # NOTE(review): a fixed port means concurrently-running test files
    # will collide — confirm tests are not run in parallel.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank,
                            world_size=world_size)
    function(rank, world_size)


class TestDistributedSGD(unittest.TestCase):
    """Smoke test for DistributedSGD across multiple worker processes."""

    def test_step(self):
        """Spawn workers and run one optimizer step in each.

        Each worker builds a tiny linear model, backpropagates a dummy
        loss, and calls ``optimizer.step()``. The parent process fails
        the test if any worker exits with a non-zero status.
        """

        def _run(rank, world_size):
            model = nn.Linear(10, 1)
            optimizer = DistributedSGD(
                model.parameters())

            optimizer.zero_grad()
            # Raw model output has shape (1,), which backward() accepts
            # as an implicit scalar; it stands in for a real loss here.
            loss = model(torch.ones(10).float())
            loss.backward()
            optimizer.step()

        processes = []
        world_size = 4
        for rank in range(world_size):
            p = Process(target=init_processes,
                        args=(rank,
                              world_size,
                              _run))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
            # A crashed worker must fail the test rather than be
            # silently ignored by a bare join().
            self.assertEqual(p.exitcode, 0)

0 comments on commit 3f697ec

Please sign in to comment.