3232from functools import wraps
3333
3434import os
35- import sys
35+ import abc
3636import shutil
3737import tempfile
3838import warnings
3939
40+ from six import add_metaclass , PY2
41+ from six .moves .urllib .request import urlopen
42+
4043import pytest
4144import numpy as np
4245
43- if sys .version_info [0 ] == 2 :
44- from urllib import urlopen
46+
47+ if PY2 :
48+ def abstractstaticmethod (func ):
49+ return func
50+ def abstractclassmethod (func ):
51+ return func
4552else :
46- from urllib .request import urlopen
53+ abstractstaticmethod = abc .abstractstaticmethod
54+ abstractclassmethod = abc .abstractclassmethod
55+
56+
57+ @add_metaclass (abc .ABCMeta )
58+ class BaseDiff (object ):
59+
60+ @abstractstaticmethod
61+ def read (filename ):
62+ """
63+ Given a filename, return a data object.
64+ """
65+ raise NotImplementedError ()
66+
67+ @abstractstaticmethod
68+ def write (filename , data , ** kwargs ):
69+ """
70+ Given a filename and a data object (and optional keyword arguments),
71+ write the data to a file.
72+ """
73+ raise NotImplementedError ()
74+
75+ @abstractclassmethod
76+ def compare (self , reference_file , test_file , atol = None , rtol = None ):
77+ """
78+ Given a reference and test filename, compare the data to the specified
79+ absolute (``atol``) and relative (``rtol``) tolerances.
80+
81+ Should return two arguments: a boolean indicating whether the data are
82+ identical, and a string giving the full error message if not.
83+ """
84+ raise NotImplementedError ()
85+
86+
87+ class SimpleArrayDiff (BaseDiff ):
88+
89+ @classmethod
90+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
91+
92+ array_ref = cls .read (reference_file )
93+ array_new = cls .read (test_file )
94+
95+ try :
96+ np .testing .assert_allclose (array_ref , array_new , atol = atol , rtol = rtol )
97+ except AssertionError as exc :
98+ message = "\n \n a: {0}" .format (test_file ) + '\n '
99+ message += "b: {0}" .format (reference_file ) + '\n '
100+ message += exc .args [0 ]
101+ return False , message
102+ else :
103+ return True , ""
47104
48105
49- class FITSDiff (object ):
106+ class FITSDiff (BaseDiff ):
50107
51108 extension = 'fits'
52109
@@ -56,12 +113,20 @@ def read(filename):
56113 return fits .getdata (filename )
57114
58115 @staticmethod
59- def write (filename , array , ** kwargs ):
116+ def write (filename , data , ** kwargs ):
60117 from astropy .io import fits
61- return fits .writeto (filename , array , ** kwargs )
118+ if isinstance (data , np .ndarray ):
119+ data = fits .PrimaryHDU (data )
120+ return data .writeto (filename , ** kwargs )
121+
122+ @classmethod
123+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
124+ from astropy .io .fits .diff import FITSDiff
125+ diff = FITSDiff (reference_file , test_file , tolerance = rtol )
126+ return diff .identical , diff .report ()
62127
63128
64- class TextDiff (object ):
129+ class TextDiff (SimpleArrayDiff ):
65130
66131 extension = 'txt'
67132
@@ -70,10 +135,10 @@ def read(filename):
70135 return np .loadtxt (filename )
71136
72137 @staticmethod
73- def write (filename , array , ** kwargs ):
138+ def write (filename , data , ** kwargs ):
74139 if 'fmt' not in kwargs :
75140 kwargs ['fmt' ] = '%g'
76- return np .savetxt (filename , array , ** kwargs )
141+ return np .savetxt (filename , data , ** kwargs )
77142
78143
79144FORMATS = {}
@@ -219,17 +284,12 @@ def item_function_wrapper(*args, **kwargs):
219284 baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
220285 shutil .copyfile (baseline_file_ref , baseline_file )
221286
222- array_ref = FORMATS [file_format ].read (baseline_file )
287+ identical , msg = FORMATS [file_format ].compare (baseline_file , test_image , atol = atol , rtol = rtol )
223288
224- try :
225- np .testing .assert_allclose (array_ref , array , atol = atol , rtol = rtol )
226- except AssertionError as exc :
227- message = "\n \n a: {0}" .format (test_image ) + '\n '
228- message += "b: {0}" .format (baseline_file ) + '\n '
229- message += exc .args [0 ]
230- raise AssertionError (message )
231-
232- shutil .rmtree (result_dir )
289+ if identical :
290+ shutil .rmtree (result_dir )
291+ else :
292+ raise Exception (msg )
233293
234294 else :
235295
0 commit comments