
Commit

Add assertRaises to tests
Tony-Y committed May 11, 2024
1 parent b7576b2 commit e94a745
Showing 6 changed files with 73 additions and 9 deletions.
26 changes: 19 additions & 7 deletions pytorch_warmup/base.py
@@ -3,6 +3,12 @@
 from torch.optim import Optimizer


+def _check_optimizer(optimizer):
+    if not isinstance(optimizer, Optimizer):
+        raise TypeError('{} ({}) is not an Optimizer.'.format(
+            optimizer, type(optimizer).__name__))
+
+
 class BaseWarmup(object):
     """Base class for all warmup schedules
@@ -13,9 +19,6 @@ class BaseWarmup(object):
     """

     def __init__(self, optimizer, warmup_params, last_step=-1):
-        if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
         self.optimizer = optimizer
         self.warmup_params = warmup_params
         self.last_step = last_step
@@ -69,19 +72,26 @@ def get_warmup_params(warmup_period, group_count):
     if isinstance(warmup_period, list):
         if len(warmup_period) != group_count:
             raise ValueError(
-                'size of warmup_period does not equal {}.'.format(group_count))
+                'The size of warmup_period ({}) does not match the size of param_groups ({}).'.format(
+                    len(warmup_period), group_count))
         for x in warmup_period:
             if not isinstance(x, int):
-                raise ValueError(
+                raise TypeError(
                     'An element in warmup_period, {}, is not an int.'.format(
                         type(x).__name__))
+            if x <= 0:
+                raise ValueError(
+                    'An element in warmup_period must be a positive integer, but is {}.'.format(x))
         warmup_params = [dict(warmup_period=x) for x in warmup_period]
     elif isinstance(warmup_period, int):
+        if warmup_period <= 0:
+            raise ValueError(
+                'warmup_period must be a positive integer, but is {}.'.format(warmup_period))
         warmup_params = [dict(warmup_period=warmup_period)
                          for _ in range(group_count)]
     else:
-        raise TypeError('{} is not a list nor an int.'.format(
-            type(warmup_period).__name__))
+        raise TypeError('{} ({}) is not a list nor an int.'.format(
+            warmup_period, type(warmup_period).__name__))
     return warmup_params


@@ -95,6 +105,7 @@ class LinearWarmup(BaseWarmup):
     """

     def __init__(self, optimizer, warmup_period, last_step=-1):
+        _check_optimizer(optimizer)
         group_count = len(optimizer.param_groups)
         warmup_params = get_warmup_params(warmup_period, group_count)
         super(LinearWarmup, self).__init__(optimizer, warmup_params, last_step)
@@ -113,6 +124,7 @@ class ExponentialWarmup(BaseWarmup):
     """

     def __init__(self, optimizer, warmup_period, last_step=-1):
+        _check_optimizer(optimizer)
         group_count = len(optimizer.param_groups)
         warmup_params = get_warmup_params(warmup_period, group_count)
         super(ExponentialWarmup, self).__init__(optimizer, warmup_params, last_step)
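
As a quick illustration of the stricter validation above (a sketch, not part of this commit; it assumes torch and pytorch_warmup are importable), the new error messages surface at construction time:

import torch
import pytorch_warmup as warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

try:
    warmup.LinearWarmup(0, warmup_period=5)          # not an Optimizer
except TypeError as e:
    print(e)  # 0 (int) is not an Optimizer.

try:
    warmup.LinearWarmup(optimizer, warmup_period=0)  # non-positive warmup_period
except ValueError as e:
    print(e)  # warmup_period must be a positive integer, but is 0.
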
3 changes: 2 additions & 1 deletion pytorch_warmup/radam.py
@@ -1,5 +1,5 @@
 import math
-from .base import BaseWarmup
+from .base import BaseWarmup, _check_optimizer


 def rho_inf_fn(beta2):
@@ -35,6 +35,7 @@ class RAdamWarmup(BaseWarmup):
     """

     def __init__(self, optimizer, last_step=-1):
+        _check_optimizer(optimizer)
         warmup_params = [
             dict(
                 beta2=x['betas'][1],
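
For context, a sketch (not from this commit; torch and pytorch_warmup assumed): RAdamWarmup reads beta2 from each param group's betas, so it expects an Adam-style optimizer, and the new guard rejects anything that is not a torch Optimizer before that lookup happens:

import torch
import pytorch_warmup as warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))
warmup_scheduler = warmup.RAdamWarmup(optimizer)  # ok: 'betas' is available

try:
    warmup.RAdamWarmup(0)  # rejected by _check_optimizer
except TypeError as e:
    print(e)  # 0 (int) is not an Optimizer.
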
4 changes: 3 additions & 1 deletion pytorch_warmup/untuned.py
@@ -1,4 +1,4 @@
-from .base import LinearWarmup, ExponentialWarmup
+from .base import LinearWarmup, ExponentialWarmup, _check_optimizer


 class UntunedLinearWarmup(LinearWarmup):
@@ -14,6 +14,7 @@ class UntunedLinearWarmup(LinearWarmup):
     """

     def __init__(self, optimizer, last_step=-1):
+        _check_optimizer(optimizer)
         def warmup_period_fn(beta2):
             return int(2.0 / (1.0-beta2))
         warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
@@ -33,6 +34,7 @@ class UntunedExponentialWarmup(ExponentialWarmup):
     """

     def __init__(self, optimizer, last_step=-1):
+        _check_optimizer(optimizer)
         def warmup_period_fn(beta2):
             return int(1.0 / (1.0-beta2))
         warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
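
A small sketch (not part of the commit; torch and pytorch_warmup assumed) of how the untuned schedules derive their warmup period from beta2, which is why they also need the optimizer check before touching param_groups:

import torch
import pytorch_warmup as warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))

# beta2 = 0.999 -> warmup_period = int(2.0 / (1.0 - 0.999)), roughly 2000 steps
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

try:
    warmup.UntunedLinearWarmup('adam')  # a string is not an Optimizer
except TypeError as e:
    print(e)  # adam (str) is not an Optimizer.
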
36 changes: 36 additions & 0 deletions test/test_base.py
@@ -14,6 +14,34 @@ def _test_state_dict(self, warmup_scheduler, constructor):
                                    warmup_scheduler_copy.__dict__[key])


+def _test_optimizer(self, warmup_class):
+    with self.assertRaises(TypeError, msg='optimizer type') as cm:
+        warmup_class(optimizer=0, warmup_period=5)
+    self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.')
+
+
+def _test_get_warmup_params(self, optimizer, warmup_class):
+    with self.assertRaises(ValueError, msg='warmup_period size') as cm:
+        warmup_class(optimizer, warmup_period=[5])
+    self.assertEqual(str(cm.exception), 'The size of warmup_period (1) does not match the size of param_groups (2).')
+
+    with self.assertRaises(TypeError, msg='warmup_period element type') as cm:
+        warmup_class(optimizer, warmup_period=[5.0, 10.0])
+    self.assertEqual(str(cm.exception), 'An element in warmup_period, float, is not an int.')
+
+    with self.assertRaises(ValueError, msg='warmup_period element range') as cm:
+        warmup_class(optimizer, warmup_period=[5, 0])
+    self.assertEqual(str(cm.exception), 'An element in warmup_period must be a positive integer, but is 0.')
+
+    with self.assertRaises(ValueError, msg='warmup_period range') as cm:
+        warmup_class(optimizer, warmup_period=0)
+    self.assertEqual(str(cm.exception), 'warmup_period must be a positive integer, but is 0.')
+
+    with self.assertRaises(TypeError, msg='warmup_period type') as cm:
+        warmup_class(optimizer, warmup_period=5.0)
+    self.assertEqual(str(cm.exception), '5.0 (float) is not a list nor an int.')
+
+
 class TestBase(unittest.TestCase):

     def setUp(self):
@@ -45,6 +73,10 @@ def test_linear(self):
         _test_state_dict(self, warmup_scheduler,
                          lambda: warmup.LinearWarmup(optimizer, warmup_period=10))

+        _test_optimizer(self, warmup.LinearWarmup)
+
+        _test_get_warmup_params(self, optimizer, warmup.LinearWarmup)
+
     def test_exponetial(self):
         p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
         p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
@@ -67,6 +99,10 @@ def test_exponetial(self):
         _test_state_dict(self, warmup_scheduler,
                          lambda: warmup.ExponentialWarmup(optimizer, warmup_period=10))

+        _test_optimizer(self, warmup.ExponentialWarmup)
+
+        _test_get_warmup_params(self, optimizer, warmup.ExponentialWarmup)
+
     def test_linear_chaining(self):
         def preparation():
             p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
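
The size-mismatch assertion in _test_get_warmup_params presupposes an optimizer with two param groups, as these tests set up. A minimal sketch of that failure mode (not from the commit; torch and pytorch_warmup assumed):

import torch
import pytorch_warmup as warmup

p1 = torch.nn.Parameter(torch.zeros(1))
p2 = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([{'params': [p1]}, {'params': [p2]}], lr=0.1)

try:
    warmup.LinearWarmup(optimizer, warmup_period=[5])  # one period for two groups
except ValueError as e:
    print(e)  # The size of warmup_period (1) does not match the size of param_groups (2).
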
3 changes: 3 additions & 0 deletions test/test_radam.py
@@ -3,6 +3,7 @@
 import pytorch_warmup as warmup

 from .test_base import _test_state_dict
+from .test_untuned import _test_optimizer


 class TestRAdam(unittest.TestCase):
@@ -31,3 +32,5 @@ def test_radam(self):

         _test_state_dict(self, warmup_scheduler,
                          lambda: warmup.RAdamWarmup(optimizer))
+
+        _test_optimizer(self, warmup.RAdamWarmup)
10 changes: 10 additions & 0 deletions test/test_untuned.py
@@ -6,6 +6,12 @@
 from .test_base import _test_state_dict


+def _test_optimizer(self, warmup_class):
+    with self.assertRaises(TypeError, msg='optimizer type') as cm:
+        warmup_class(optimizer=0)
+    self.assertEqual(str(cm.exception), '0 (int) is not an Optimizer.')
+
+
 class TestUntuned(unittest.TestCase):

     def setUp(self):
@@ -37,6 +43,8 @@ def test_untuned_linear(self):
         _test_state_dict(self, warmup_scheduler,
                          lambda: warmup.UntunedLinearWarmup(optimizer))

+        _test_optimizer(self, warmup.UntunedLinearWarmup)
+
     def test_untuned_exponetial(self):
         p1 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
         p2 = torch.nn.Parameter(torch.arange(10, dtype=torch.float32).to(self.device))
@@ -58,3 +66,5 @@ def test_untuned_exponetial(self):

         _test_state_dict(self, warmup_scheduler,
                          lambda: warmup.UntunedExponentialWarmup(optimizer))
+
+        _test_optimizer(self, warmup.UntunedExponentialWarmup)
