Unit Testing for pytorch, based on mltest

BrianPugh/torchtest

torchtest

A tiny test suite for PyTorch-based machine learning models, inspired by mltest. Chase Roberts lists four basic tests in his Medium post about mltest. torchtest is mostly a PyTorch port of mltest (which was written for TensorFlow).

Installation

pip install --upgrade torchtest

Tests

# imports for examples
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

Variables Change

from torchtest import assert_vars_change

inputs = Variable(torch.randn(20, 20))
targets = Variable(torch.randint(0, 2, (20,))).long()
batch = [inputs, targets]
model = nn.Linear(20, 2)

# what are the variables?
print('Our list of parameters', [ np[0] for np in model.named_parameters() ])

# do they change after a training step?
#  let's run a train step and see
assert_vars_change(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(model.parameters()),
    batch=batch)
""" FAILURE """
# let's try to break this, so the test fails
# compare names with != ; `is not` tests identity and is unreliable for strings
params_to_train = [ np[1] for np in model.named_parameters() if np[0] != 'bias' ]
# run test now
assert_vars_change(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(params_to_train),
    batch=batch)

# the assertion fails, as intended: bias did not change
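
To make the idea behind this test concrete, here is a minimal sketch of a "did the parameters change?" check written in plain PyTorch. It is not torchtest's implementation; the helper name `changed_params` is hypothetical, and the only detail taken from torchtest is the use of `torch.equal` to compare tensors, as seen in the traceback in the Debugging section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def changed_params(model, loss_fn, optim, batch):
    """Hypothetical helper: run one training step and report, per parameter
    name, whether its values differ from the pre-step snapshot."""
    inputs, targets = batch
    # clone() matters: without it the snapshot would alias the live tensor
    # and the comparison below would always report "unchanged"
    before = {name: p.detach().clone() for name, p in model.named_parameters()}

    model.train()
    optim.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optim.step()

    return {name: not torch.equal(before[name], p.detach())
            for name, p in model.named_parameters()}


torch.manual_seed(0)
batch = [torch.randn(20, 20), torch.randint(0, 2, (20,))]
model = nn.Linear(20, 2)

# the optimizer only receives the weight, so the bias is never updated
weight_only = [p for name, p in model.named_parameters() if name != 'bias']
result = changed_params(model, F.cross_entropy,
                        torch.optim.Adam(weight_only), batch)
print(result)
```

A parameter that the optimizer never sees still receives a gradient, but its values stay put, which is exactly the situation the failing example above sets up.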

Variables Don’t Change

from torchtest import assert_vars_same

# What if bias is not supposed to change, by design?
#  test to see if bias remains the same after training
assert_vars_same(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(params_to_train),
    batch=batch,
    params=[('bias', model.bias)]
    )
# it does? good. let's move on

Output Range

from torchtest import test_suite

# NOTE : bias is fixed (not trainable)
optim = torch.optim.Adam(params_to_train)
loss_fn = F.cross_entropy

test_suite(model, loss_fn, optim, batch,
    output_range=(-2, 2),
    test_output_range=True
    )

# seems to work
""" FAILURE """
#  let's tweak the model to fail the test
model.bias = nn.Parameter(2 + torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    output_range=(-1, 1),
    test_output_range=True
    )

# as expected, it fails; yay!
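
For intuition, an output range check can be sketched as a forward pass followed by a bounds comparison. The helper `output_in_range` below is hypothetical, and treating the bounds as exclusive is an assumption; torchtest may compare differently. The weights are zeroed so that the output equals the bias and the result is deterministic.

```python
import torch
import torch.nn as nn


def output_in_range(model, inputs, output_range):
    """Hypothetical helper: True if every output value lies strictly
    between the two bounds (exclusive bounds are an assumption)."""
    low, high = output_range
    model.eval()
    with torch.no_grad():
        out = model(inputs)
    return bool((out > low).all() and (out < high).all())


model = nn.Linear(20, 2)
inputs = torch.randn(20, 20)

# zero the weights so every output value equals the bias
with torch.no_grad():
    model.weight.zero_()
    model.bias.fill_(0.5)
inside = output_in_range(model, inputs, (-1, 1))

# push the bias outside the allowed range
with torch.no_grad():
    model.bias.fill_(3.0)
outside = output_in_range(model, inputs, (-1, 1))

print(inside, outside)
```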

NaN Tensors

""" FAILURE """
model.bias = nn.Parameter(float('NaN') * torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    test_nan_vals=True
    )

Inf Tensors

""" FAILURE """
model.bias = nn.Parameter(float('Inf') * torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    test_inf_vals=True
    )
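
The NaN and Inf checks can be illustrated with `torch.isnan` and `torch.isinf`. This is a sketch of the general technique, not torchtest's own code; `has_nan` and `has_inf` are hypothetical helper names.

```python
import torch
import torch.nn as nn


def has_nan(tensor):
    """Hypothetical helper: True if any element is NaN."""
    return bool(torch.isnan(tensor).any())


def has_inf(tensor):
    """Hypothetical helper: True if any element is +Inf or -Inf."""
    return bool(torch.isinf(tensor).any())


model = nn.Linear(4, 2)
inputs = torch.randn(8, 4)

# a NaN bias poisons every output value
model.bias = nn.Parameter(torch.full((2,), float('nan')))
with torch.no_grad():
    nan_out = model(inputs)

# an infinite bias makes every output value infinite
model.bias = nn.Parameter(torch.full((2,), float('inf')))
with torch.no_grad():
    inf_out = model(inputs)

print(has_nan(nan_out), has_inf(nan_out))
print(has_nan(inf_out), has_inf(inf_out))
```

Note that the two conditions are distinct: a NaN is not reported by `torch.isinf`, and an infinite value is not reported by `torch.isnan`, so both checks are needed.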

Debugging

torchtest\torchtest.py", line 151, in _var_change_helper
assert not torch.equal(p0, p1)
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'other'

When running on a GPU, pass the device explicitly, e.g. device='cuda:0'. The device defaults to 'cpu'; the error above is what you get when the model lives on the GPU and the default is left in place. See issue #1 for more information.

test_suite(
    model,  # a model moved to GPU
    loss_fn, optim, batch,
    test_inf_vals=True,
    device='cuda:0'
    )
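
One way to avoid hard-coding the device is to pick it at runtime and keep the model and batch on the same device. The snippet below is a sketch in plain PyTorch and does not call torchtest; whether `test_suite` moves the batch itself is not stated here, so moving it explicitly is an assumption made for safety.

```python
import torch
import torch.nn as nn

# fall back to the CPU when no GPU is present
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(20, 2).to(device)
batch = [torch.randn(20, 20), torch.randint(0, 2, (20,))]
batch = [t.to(device) for t in batch]

# create the optimizer after the model has been moved
optim = torch.optim.Adam(model.parameters())

out = model(batch[0])
device_arg = str(device)  # 'cuda:0' or 'cpu'
print(device_arg, out.device)
```

The resulting `device_arg` string has the same form as the `device='cuda:0'` value passed in the call above.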

Citation

@misc{Ram2019,
  author = {Suriyadeepan Ramamoorthy},
  title = {torchtest},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/suriyadeepan/torchtest}},
  commit = {42ba442e54e5117de80f761a796fba3589f9b223}
}
