Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/test_fastai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import fastai
from fastai.tabular.all import *

from common import p100_exempt


class TestFastAI(unittest.TestCase):
# Basic import
Expand All @@ -22,6 +24,7 @@ def test_torch_tensor(self):

self.assertTrue(torch.all(a == b))

@p100_exempt
def test_tabular(self):
dls = TabularDataLoaders.from_csv(
"/input/tests/data/train.csv",
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as tnn
import torch.autograd as autograd

from common import gpu_test
from common import gpu_test, p100_exempt


class TestPyTorch(unittest.TestCase):
Expand All @@ -16,6 +16,7 @@ def test_nn(self):
linear_torch(data_torch)

@gpu_test
@p100_exempt
def test_linalg(self):
A = torch.randn(3, 3).t().to('cuda')
B = torch.randn(3).t().to('cuda')
Expand All @@ -24,6 +25,7 @@ def test_linalg(self):
self.assertEqual(3, result.shape[0])

@gpu_test
@p100_exempt
def test_gpu_computation(self):
cuda = torch.device('cuda')
a = torch.tensor([1., 2.], device=cuda)
Expand All @@ -33,6 +35,7 @@ def test_gpu_computation(self):
self.assertEqual(torch.tensor([3.], device=cuda), result)

@gpu_test
@p100_exempt
def test_cuda_nn(self):
# These throw if cuda is misconfigured
tnn.GRUCell(10,10).cuda()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from common import p100_exempt


class LitDataModule(pl.LightningDataModule):

Expand Down Expand Up @@ -59,6 +61,7 @@ class TestPytorchLightning(unittest.TestCase):
def test_version(self):
self.assertIsNotNone(pl.__version__)

@p100_exempt
def test_mnist(self):
dm = LitDataModule()
model = LitClassifier()
Expand Down
Loading