We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
此任务的开始依赖PR 5520合并进Master。PR 5520已合并到Master。 背景 之前在issue 5352中的自动测试重构计划是基于老版的自动测试框架,例如无法控制Module传入的参数列表,无法控制只测试有前向的API等等,有较多无法解决的行为并且写法也不自然。新的自动测试框架改善了这些缺点,并提供了一个和原始Pytorch调用API类似的单测写法。所以,我们期望将之前任务分工里面所有API的单测重构迁移到新的自动测试框架中。对于每个API,无论是否迁移过旧版的自动测试方法,均需要用新版的自动测试方法来重写以统一标准。全部迁移完成后,旧版自动测试的接口将删除。 项目负责人 张晓雨,梁德澎,姚迟,赵露阳,郑泽康,任天和,李春游,刘沛宏,张申,也欢迎之江小伙伴们认领 任务分工表 领任务的时候Module如果有Tensor方法需要一起领取。 有几个Loss相关的单测,建议谁开发的谁来修改,应该Loss的测试比较复杂,如果测出了BUG也好自行解决。 gt, le, ne 等比较运算符Module没有加广播参数测试,注意加上。 使用方法 举例 对于nn.Module,以卷积为例展示用法: @unittest.skip("need a more relaxed tolerance") @autotest() def test_against_pytorch(test_case): channels = random(1, 6) m = torch.nn.Conv2d(channels, random(1, 6), random(1, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), dilation=random(1, 3) | nothing(), groups=random(1, 3) | nothing(), bias=random() | nothing(), padding_mode=constant('zeros') | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_pytorch_tensor(ndim=4, dim1=channels, dim2=random(1, 8), dim3=random(1, 8)).to(device) y = m(x) return y 可以看到新版的自动测试写法和原始的Pytorch构建一个Module是非常类似的,首先声明Conv2d Module记作m,然后调用m.train(random())设置Module是train还是eval模式,最后构造输入然后进行预测,返回输出Tensor。 使用的时候需要注意的是,Module里面你想测试的参数都需要自行指定,比如Conv2d里面的padding指定为random(1, 3) | nothing()表示padding获得的值有可能是自动测试框架随机生成的(对应random(1, 3)),也有可能是Pytorch或者OneFlow框架给出的参数默认值(对应nothing)。 对于flow.xxx方法,以matmul为例: @autotest() def test_flow_matmul_with_random_data(test_case): k = random(1, 6) x = random_pytorch_tensor(ndim=2, dim1=k) y = random_pytorch_tensor(ndim=2, dim0=k) z = torch.matmul(x, y) return z 对于flow.matmul的测试,我们基于random_pytorch_tensor方法构造了两个随机Tensor x和y,它们的维度分别是[m, k]和[k, n],这些维度的值都是随机生成的。 对于Tensor.xxx的测试方法,以Tanh为例: @autotest() def test_tensor_tan(test_case): x = random_pytorch_tensor().to(random_device()) return x.tan() 和flow.xxx方法的测试类似,这里就不重复了。 小结 新版的自动测试框架最大的好处就是我们可以像写Pytorch代码那样去写一个测试样例,并且我们可以随意组合参数输入数据的genator来产生各种不同的数据。另外@autotest()这个装饰器还有一些可选参数让我们灵活控制测试,这里看一下定义def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5)。例如对于greater这个函数,它不支持反向,那么我们就设置auto_backward=False,如果某个Module误差会比1e-4还大,那么我们可以调整rtol,再比如我们想控制随机测试的有效次数,则可以通过设置n来解决。 实现思路(进阶) 了解了使用方法之后,这一节我们来了解一下新版自动测试框架的实现思路。我主要从两个方面来讲,首先是讲解自动测试框架中的generators,然后讲解自动测试框架中和Pytorch交互部分的核心实现。 generators 实现在:oneflow/python/test/modules/automated_test_util/generators.py。首先我们从最后几行可以了解到这个文件主要导出了一些generator方法: __all__ = [ "random_tensor", "random_bool", "random_device", "random", "random_or_nothing", "constant", "nothing", "test_module_against_pytorch", "test_flow_against_pytorch", "test_tensor_against_pytorch", ] 在这个文件中,从def test_against_pytorch这行代码开始到结束的代码在新版的测试框架中已经没用了,只是为了兼容前一个版本大家写的自动测试方法的PR,在大家把之前的自动测试方法迁移到新版后会移除这些代码。这样的话,这个文件就剩下了一些generator的实现。random,random_bool ,constant所有这些generator都继承了generator基类,基类的定义如下: class generator: def __init__(self, children): self.children = children self._value = None def _init(self): self._value = None for x in self.children: x._init() def eval(self): self._init() return self.value() def _calc_value(self): raise NotImplementedError() def value(self): if self._value is None: self._value = self._calc_value() return self._value def size(self): return 1 def __or__(self, other): other = pack(other) return oneof( self, other, possibility=self.size() / (self.size() + other.size()) ) def __ror__(self, other): return self | other def __add__(self, other): return add(self, other) def __radd__(self, other): return self + other def __sub__(self, other): return self + neg(other) def __rsub__(self, other): return neg(self - other) def to(self, annotation): self._to(annotation) for x in self.children: x.to(annotation) return self def _to(self, annotation): pass 所有的generator子类都继承了这个基类,并重写其中的__init__和__calc_value,size等成员函数。例如Nothing就是直接在_calc_value里面返回一个空的class, 实现如下: class Nothing: pass class nothing(generator): def __init__(self): super().__init__([]) def _calc_value(self): return Nothing() 再例如,random这个子类定义如下: class random(generator): def __init__(self, low=1, high=6): self.low = pack(low) self.high = pack(high) super().__init__([self.low, self.high]) self.annotation = None def _to(self, annotation): if self.annotation is not None: return if hasattr(annotation, "__origin__"): # PyTorch _size_2_t and similar types are defined by type variables, # leading to unexpected __args__ and __origin__ # # >>> _size_2_t = Union[T, Tuple[T, T]][int] # >>> _size_2_t.__origin__ # typing.Union[~T, typing.Tuple[~T, ~T]] # # So recreate a new annotation object by repr and eval # # >>> _size_2_t # typing.Union[int, typing.Tuple[int, int]] # >>> _size_2_t_new = eval(repr(annotation)) # >>> _size_2_t_new.__origin__ # typing.Union annotation = eval(repr(annotation)) self.annotation = annotation def _generate(self, annotation): if hasattr(annotation, "__origin__"): if annotation.__origin__ is Union: x = random_util.choice(annotation.__args__) return self._generate(x) if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple: return [self._generate(x) for x in annotation.__args__] else: raise NotImplementedError( f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}" ) low, high = self.low.value(), self.high.value() if annotation == int: val = int(rng.integers(low, high)) elif annotation == float: val = float(rng.random() * (high - low) + low) elif annotation == bool: val = random_util.choice([True, False]) else: raise NotImplementedError( f"Not implemented annotation {annotation} in random" ) return val def _calc_value(self): return self._generate(self.annotation) def random_or_nothing(low, high): return oneof(random(low, high), nothing(), possibility=2 / 3) 这里需要注意的是annoation是在下一个阶段传进来的参数的Type,比如int, float, bool,__size_2_t等等。获取了annoation的值这个generator才可以产生真正合法的数据。 torch_flow_dual_object 实现思路 这是新版测试框架的核心实现,代码在:oneflow/python/test/modules/automated_test_util/torch_flow_dual_object.py 。我们依然从这个文件导出了什么开始:__all__ = ["torch", "autotest", "random_pytorch_tensor"] 首先导出了torch,这个torch可以理解为是原始Pytorch的更高层封装,这个封装体现在参数的输入数据可以用上一节的generators来进行组合。导出high level 的torch的代码如下:torch = GetDualObject("", torch_original, flow)。 所以核心实现是GetDualObject这个函数,我们来看一下这个函数做了什么? class DualObject: def __init__(self, name, pytorch, oneflow): self.name = name self.pytorch = pytorch self.oneflow = oneflow if isinstance(pytorch, torch_original.nn.Module): state_dict = pytorch.state_dict() state_dict = {k: v.detach().cpu().numpy() for k, v in state_dict.items()} oneflow.load_state_dict(state_dict) dual_modules_to_test.append(self) if isinstance(pytorch, torch_original.Tensor): dual_objects_to_test.append(self) def __repr__(self): return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}" def __getattr__(self, key): pytorch_attr = getattr(self.pytorch, key) oneflow_attr = getattr(self.oneflow, key) new_name = f"{self.name}.{key}" return GetDualObject(new_name, pytorch_attr, oneflow_attr) 在初始化里面首先传了两个Python对象,分别是Pytorch和OneFlow,在导出high level的torch的时候传的是:torch_original和flow,而导出random_pytorch_tensor的时候传的是:pytorch_tensor和flow_tensor。这里不妨先列出random_pytorch_tensor这个函数的实现: def random_pytorch_tensor( ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None, requires_grad=True ): if isinstance(requires_grad, generator): requires_grad = requires_grad.value() pytorch_tensor = ( random_tensor(ndim, dim0, dim1, dim2, dim3, dim4) .value() .requires_grad_(requires_grad) ) flow_tensor = flow.tensor(pytorch_tensor.detach().cpu().numpy(), requires_grad=True) return GetDualObject("unused", pytorch_tensor, flow_tensor) 可以看到它和导出high level的torch是一样的实现思路, 仍然调用了GetDualObject这个类的构造函数。 继续回到DualObject的实现,我们发现这里分别使用了dual_modules_to_test和dual_objects_to_test这两个列表分别来记录OneFlow和Pytorch的Module和Tensor对象。并重写了__getattr__魔法方法,以Flatten为例子查看一下它具体做了什么? def __getattr__(self, key): pytorch_attr = getattr(self.pytorch, key) oneflow_attr = getattr(self.oneflow, key) print(key) # print(pytorch_attr) # print(oneflow_attr) new_name = f"{self.name}.{key}" return GetDualObject(new_name, pytorch_attr, oneflow_attr) # flatten的测试程序 @autotest(auto_backward=False) def test_against_pytorch(test_case): m = torch.nn.Flatten( start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_pytorch_tensor().to(device) y = m(x) return y 然后看一下__getattr__中key的打印结果,重复打印和省略号是因为有20轮测试: nn Flatten train to to nn Flatten train to to nn Flatten train to ... 可以看到由@autotest这个装饰器包起来的程序的Pytorch Module或者函数都重写了这个方法,它将这些Module或者函数的参数取出来用GetDualObject返回一个新的DualObject对象。我们可以打印一下Flatten这个Module对应的DualObject对象是什么? PyTorch object: <bound method Module.train of Flatten(start_dim=1, end_dim=-1)> OneFlow object: <bound method Module.train of Flatten(start_dim=1, end_dim=-1)> GetDualObject这个函数就是根据传入的Pytorch以及OneFlow对象和它们的名字(两个类的名字必须是相同的,这样才是和Pytorch对齐)来生成一个DualObject对象。GetDualObject这个函数会为Pytorch重写传入的Pytorch以及OneFlow对象的魔法函数,返回一个DualObject对象,这个过程还包含了跳过一些不合法的魔法函数以及检查传入对象的属性是否合法。这里还有一句对于Tensor方法的特判,因为Tensor的API调用方式和其它Module和函数不同。 接下来,就是看一下autotest装饰器的实现了: def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5): verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None def deco(f): @functools.wraps(f) def new_f(test_case): nonlocal n while n > 0: dual_modules_to_test.clear() dual_objects_to_test.clear() try: res = f(test_case) except PyTorchDoesNotSupportError as e: if verbose: print(e) continue # TODO: support types other than Tensor, like torch.Size/flow.Size if res is not None: if not isinstance(res, collections.abc.Sequence): res = [res] for x in res: if auto_backward: if isinstance(x.pytorch, torch_original.Tensor): x.sum().backward() dual_objects_to_test.append(x) for x in dual_modules_to_test: # x.state_dict().values() returns dual object with inconsistent values for key in x.pytorch.state_dict().keys(): dual_objects_to_test.append( GetDualObject( "unused", x.pytorch.state_dict()[key], x.oneflow.state_dict()[key], ) ) for x in dual_objects_to_test: test_case.assertTrue(check_equality(x)) if verbose: print("test passed") n -= 1 return new_f return deco 最后,这个装饰器把包起来的high level的程序执行一遍,获得每个中间Tensor,最后再对每个Tensor进行判断是否在合法的精度范围内。 所以核心的实现是上面产生DualObject的过程,它完成了high-level的Pytorch往原始的pytorch的oneflow的转换,让转换后的张量程序可以直接运行,以获取所有中间结果。 补充说明 在介绍generators的时候讲到各个generator 的annoations必须获取到具体参数类型才可以产生真正的合法数字,这个过程是在GetDualObject这个函数中的get_args函数完成的,里面的get_generator_value 会遍历所有的参数列表,将dtype传给这些列表中的generator生成真正的合法数据。 def get_generator_value(x): if isinstance(x, generator): return x.value() return x 分工 重构Module 认领人 reviewer PR 备注 oneflow.nn.ReLU 张晓雨 大缺弦 #5562 已完成 oneflow.nn.ReLU6 张晓雨 大缺弦 #5562 已完成 oneflow.nn.LeakyReLU 张晓雨 大缺弦 #5562 已完成 oneflow.nn.Tanh 张晓雨 大缺弦 #5562 已完成 oneflow.tanh 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.tanh 张晓雨 大缺弦 #5562 已完成 oneflow.asin 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.asin 张晓雨 大缺弦 #5562 已完成 oneflow.arcsin 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.arcsin 张晓雨 大缺弦 #5562 已完成 oneflow.asinh 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.asinh 张晓雨 大缺弦 #5562 已完成 oneflow.sinh 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.sinh 张晓雨 大缺弦 #5562 已完成 oneflow.atan2 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.atan2 张晓雨 大缺弦 #5562 已完成 oneflow.softplus 张晓雨 大缺弦 #5562 已完成 oneflow.arcsinh 张晓雨 大缺弦 #5562 已完成 oneflow.Tensor.arcsinh 张晓雨 大缺弦 #5562 已完成 oneflow.nn.ELU 张晓雨 大缺弦 #5562 已完成 oneflow.nn.GELU 刘沛鸿 BBuf #5646 已完成 oneflow.gelu 刘沛鸿 BBuf #5646 已完成 oneflow.Tensor.gelu 刘沛鸿 BBuf #5646 已完成 oneflow.nn.Sigmoid 刘沛鸿 BBuf #5646 已完成 oneflow.sigmoid 刘沛鸿 BBuf #5646 已完成 oneflow.Tensor.sigmoid 刘沛鸿 BBuf #5646 已完成 oneflow.nn.Hardsigmoid 郑泽康 BBuf #5646 已完成 oneflow.softmax 张晓雨 zzk #5899 已完成 oneflow.Tensor.softmax 张晓雨 zzk #5899 已完成 oneflow.nn.LogSigmoid 刘沛鸿 BBuf #5646 已完成 oneflow.nn.Softplus 张晓雨 PyTorch有bug oneflow.nn.LogSoftmax 张晓雨 zzk #5899 已完成 oneflow.nn.Mish 郑泽康 PyTorch没有 oneflow.mish 郑泽康 PyTorch没有 oneflow.Tensor.mish 郑泽康 PyTorch没有 oneflow.acosh 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.acosh 张晓雨 郑泽康 #5654 已完成 oneflow.arccosh 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.arccosh 张晓雨 郑泽康 #5654 已完成 oneflow.arange 赵露阳 张晓雨 #5666 已完成 oneflow.argwhere 赵露阳 张晓雨 #5666 已完成 oneflow.Tensor.argwhere 赵露阳 张晓雨 #5666 已完成 oneflow.argmax 赵露阳 张晓雨 #5666 已完成 oneflow.Tensor.argmax 赵露阳 张晓雨 #5666 已完成 oneflow.nn.BatchNorm1d 赵露阳 oneflow.nn.BatchNorm2d 石永涛 张晓雨 #5631 已完成 oneflow.nn.ReplicationPad2d 赵露阳 张晓雨 #5666 已完成 oneflow.nn.InstanceNorm1d 赵露阳 张晓雨 #5666 已完成 oneflow.nn.InstanceNorm2d 赵露阳 张晓雨 #5666 已完成 oneflow.nn.InstanceNorm3d 赵露阳 张晓雨 #5666 已完成 oneflow.nn.LayerNorm 赵露阳 张晓雨 #5666 已完成 oneflow.cast 梁德澎 oneflow.Tensor.cast 梁德澎 oneflow.cat 梁德澎 oneflow.ones 梁德澎 oneflow.zeros 梁德澎 oneflow.zeros_like 梁德澎 oneflow.ones_like 梁德澎 oneflow.nn.Module - oneflow.nn.Parameter - oneflow.nn.Sequential - oneflow.nn.ParameterList - oneflow.nn.ParameterDict - oneflow.nn.ModuleList - oneflow.nn.ModuleDict - oneflow.nn.Conv2d 梁德澎 oneflow.nn.ConstantPad2d 梁德澎 oneflow.nn.ConvTranspose2d 梁德澎 oneflow.nn.Dropout 梁德澎 oneflow.eq 任天和 张晓雨 #5599 已完成 oneflow.to 任天和 张晓雨 #5599 已完成 oneflow.Tensor.to 任天和 张晓雨 #5599 已完成 oneflow.equal 任天和 张晓雨 #5599 有BUG,正在解决 oneflow.Tensor.eq 任天和 张晓雨 #5599 已完成 oneflow.exp 任天和 张晓雨 #5599 已完成 oneflow.Tensor.exp 任天和 张晓雨 #5599 已完成 oneflow.erf 任天和 张晓雨 #5599 已完成 oneflow.Tensor.erf 任天和 张晓雨 #5599 已完成 oneflow.erfc 任天和 张晓雨 #5599 已完成 oneflow.Tensor.erfc 任天和 张晓雨 #5599 已完成 oneflow.round 任天和 张晓雨 #5599 已完成 oneflow.Tensor.round 任天和 张晓雨 #5599 已完成 oneflow.Tensor.expand 任天和 张晓雨 #5599 已完成 oneflow.nn.Flatten 任天和 张晓雨 #5599 已完成 oneflow.flatten 任天和 张晓雨 #5599 已完成 oneflow.Tensor.flatten 赵天宇 张晓雨 #5561 已完成 oneflow.gt 赵天宇 张晓雨 #5561 已完成 oneflow.Tensor.gt 赵天宇 张晓雨 #5561 已完成 oneflow.lt 赵天宇 张晓雨 #5561 已完成 oneflow.Tensor.lt 赵天宇 张晓雨 #5561 已完成 oneflow.nn.Identity 赵天宇 张晓雨 #5561 已完成 oneflow.nn.PixelShuffle 赵天宇 张晓雨 #5561 已完成 oneflow.nn.Linear 赵天宇 张晓雨 #5561 已完成 oneflow.nn.CrossEntropyLoss 赵天宇 张晓雨 #5561 有BUG,未解决 oneflow.nn.CTCLoss 黄振华 oneflow.nn.L1Loss oneflow.nn.BCELoss oneflow.nn.NLLLoss oneflow.nn.KLDivLoss oneflow.nn.MSELoss oneflow.nn.MarginRankingLoss oneflow.nn.BCEWithLogitsLoss oneflow.masked_fill 姚迟 oneflow.Tensor.masked_fill 姚迟 oneflow.sum 姚迟 oneflow.Tensor.sum 姚迟 oneflow.min 姚迟 oneflow.Tensor.min 姚迟 oneflow.max 姚迟 oneflow.Tensor.max 姚迟 oneflow.mul 姚迟 oneflow.Tensor.mul 姚迟 oneflow.mean 姚迟 oneflow.Tensor.mean 姚迟 oneflow.sub 姚迟 oneflow.Tensor.sub 姚迟 oneflow.div 姚迟 oneflow.Tensor.div 姚迟 oneflow.var 姚迟 oneflow.Tensor.var 姚迟 oneflow.reciprocal 李春游 oneflow.Tensor.reciprocal 李春游 oneflow.add 李春游 oneflow.Tensor.add 李春游 oneflow.sign 李春游 oneflow.Tensor.sign 李春游 oneflow.sin 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.sin 张晓雨 郑泽康 #5654 已完成 oneflow.atan 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.atan 张晓雨 郑泽康 #5654 已完成 oneflow.arctan 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.arctan 张晓雨 郑泽康 #5654 已完成 oneflow.cos 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.cos 张晓雨 郑泽康 #5654 已完成 oneflow.log 李春游 oneflow.Tensor.log 李春游 oneflow.sqrt 刘沛宏 张晓雨 #5646 oneflow.Tensor.sqrt 刘沛宏 张晓雨 #5646 oneflow.square 刘沛宏 张晓雨 #5646 oneflow.Tensor.square 刘沛宏 张晓雨 #5646 oneflow.std 刘沛宏 张晓雨 #5646 有bug oneflow.Tensor.std 刘沛宏 张晓雨 #5646 有bug oneflow.pow 张晓雨 郑泽康 #5646 oneflow.Tensor.pow 张晓雨 郑泽康 #5646 oneflow.cosh 张晓雨 姚迟 #5654 已完成 oneflow.Tensor.cosh 张晓雨 郑泽康 #5654 已完成 oneflow.acos 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.acos 张晓雨 郑泽康 #5654 已完成 oneflow.matmul 张申 张晓雨 #5600 已完成 oneflow.Tensor.matmul 张申 张晓雨 #5600 已完成 oneflow.negative 张申 张晓雨 #5600 已完成 oneflow.neg 张申 张晓雨 #5600 已完成 oneflow.Tensor.negative 张申 张晓雨 #5600 已完成 oneflow.nn.AvgPool1d 张申 张晓雨 #5600 有bug oneflow.nn.AvgPool2d 张申 张晓雨 #5600 有bug oneflow.nn.AvgPool3d 张申 张晓雨 #5600 梯度有bug oneflow.nn.AdaptiveAvgPool2d 张申 张晓雨 #5600 已完成 oneflow.nn.MaxPool1d 张申 张晓雨 #5600 梯度有bug oneflow.nn.MaxPool2d 张申 张晓雨 #5600 梯度有bug oneflow.nn.MaxPool3d 张申 张晓雨 #5600 梯度有bug oneflow.repeat 张申 张晓雨 #5600 pytorch没有 oneflow.Tensor.repeat 张申 张晓雨 #5600 已完成 oneflow.tile 张申 张晓雨 #5600 已完成 oneflow.Tensor.tile 张申 张晓雨 #5600 已完成 oneflow.reshape 张晓雨 赵露阳 #5588 已完成 oneflow.Tensor.reshape 张晓雨 赵露阳 #5588 已完成 oneflow.squeeze 张晓雨 赵露阳 #5588 已完成 oneflow.Tensor.squeeze 张晓雨 赵露阳 #5588 已完成 oneflow.transpose 张晓雨 赵露阳 #5588 已完成 oneflow.Tensor.transpose 张晓雨 赵露阳 #5588 已完成 oneflow.unsqueeze 张晓雨 赵露阳 #5588 已完成 oneflow.Tensor.unsqueeze 张晓雨 赵露阳 #5588 已完成 oneflow.where 张晓雨 oneflow.Tensor.where 张晓雨 oneflow.gather oneflow.Tensor.gather oneflow.nn.Embedding oneflow.Tensor.permute 张晓雨 赵露阳 #5588 已完成 oneflow.nn.Hardswish 张晓雨 赵露阳 #5588 已完成 oneflow.nn.PReLU 张晓雨 王迎港 #5529 oneflow.nn.Hardtanh 张晓雨 赵露阳 #5588 已完成 oneflow.nn.Upsample Pytorh存在一个BUG,不好随机自动测试 oneflow.nn.UpsamplingNearest2d Pytorh存在一个BUG,不好随机自动测试 oneflow.nn.UpsamplingBilinear2d Pytorh存在一个BUG,不好随机自动测试 oneflow.linalg.norm oneflow.Tensor.norm oneflow.floor 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.floor 张晓雨 郑泽康 #5654 已完成 oneflow.addmm 张晓雨 赵露阳 #5605 oneflow.Tensor.addmm 张晓雨 赵露阳 #5605 oneflow.clamp 张晓雨 赵露阳 #5605 oneflow.Tensor.clamp 张晓雨 赵露阳 #5605 oneflow.clip 张晓雨 赵露阳 #5605 oneflow.Tensor.clip 张晓雨 赵露阳 #5605 oneflow.atanh 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.atanh 张晓雨 郑泽康 #5654 已完成 oneflow.arctanh 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.arctanh 张晓雨 郑泽康 #5654 已完成 oneflow.tan 张晓雨 郑泽康 #5654 已完成 oneflow.Tensor.tan 张晓雨 郑泽康 #5654 已完成 oneflow.log1p 黄振华 oneflow.Tensor.log1p 黄振华 oneflow.ceil 张晓雨 赵露阳 #5605 已完成 oneflow.Tensor.ceil 张晓雨 赵露阳 #5605 已完成 oneflow.expm1 张晓雨 赵露阳 #5605 已完成 oneflow.Tensor.expm1 张晓雨 赵露阳 #5605 已完成 oneflow.nn.ReflectionPad2d oneflow.meshgrid 张晓雨 郑泽康 #5899 已完成 oneflow.topk oneflow.Tensor.topk oneflow.nn.GroupNorm oneflow.nn.ZeroPad2d oneflow.tensor_buffer_to_tensor oneflow.tensor_to_tensor_buffer oneflow.new_ones oneflow.Conv1d oneflow.Conv3d 郑泽康 张晓雨 #5327 已完成 oneflow.ConstantPad1d oneflow.chunk oneflow.Tensor.chunk oneflow.masked_select oneflow.Tensor.masked_select oneflow.nn.AdaptiveAvgPool1d oneflow.nn.AdaptiveAvgPool3d oneflow.adaptive_avg_pool1d oneflow.adaptive_avg_pool2d oneflow.adaptive_avg_pool3d oneflow.fmod oneflow.Tensor.fmod oneflow.view oneflow.Tensor.view oneflow.flip oneflow.Tensor.flip oneflow.nn.functional.interpolate oneflow.linalg.vector_norm oneflow.linalg.matrix_norm oneflow.diag oneflow.Tensor.diag oneflow.gather_nd oneflow.scatter_nd oneflow.nn.image.flip oneflow.Tensor.type_as oneflow.Tensor.long oneflow.bernoulli oneflow.in_top_k oneflow.Tensor.in_top_k
此任务的开始依赖PR 5520合并进Master。PR 5520已合并到Master。
之前在issue 5352中的自动测试重构计划是基于老版的自动测试框架,例如无法控制Module传入的参数列表,无法控制只测试有前向的API等等,有较多无法解决的行为并且写法也不自然。新的自动测试框架改善了这些缺点,并提供了一个和原始Pytorch调用API类似的单测写法。所以,我们期望将之前任务分工里面所有API的单测重构迁移到新的自动测试框架中。对于每个API,无论是否迁移过旧版的自动测试方法,均需要用新版的自动测试方法来重写以统一标准。全部迁移完成后,旧版自动测试的接口将删除。
张晓雨,梁德澎,姚迟,赵露阳,郑泽康,任天和,李春游,刘沛宏,张申,也欢迎之江小伙伴们认领
@unittest.skip("need a more relaxed tolerance") @autotest() def test_against_pytorch(test_case): channels = random(1, 6) m = torch.nn.Conv2d(channels, random(1, 6), random(1, 6), stride=random(1, 3) | nothing(), padding=random(1, 3) | nothing(), dilation=random(1, 3) | nothing(), groups=random(1, 3) | nothing(), bias=random() | nothing(), padding_mode=constant('zeros') | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_pytorch_tensor(ndim=4, dim1=channels, dim2=random(1, 8), dim3=random(1, 8)).to(device) y = m(x) return y
可以看到新版的自动测试写法和原始的Pytorch构建一个Module是非常类似的,首先声明Conv2d Module记作m,然后调用m.train(random())设置Module是train还是eval模式,最后构造输入然后进行预测,返回输出Tensor。
m
m.train(random())
train
eval
使用的时候需要注意的是,Module里面你想测试的参数都需要自行指定,比如Conv2d里面的padding指定为random(1, 3) | nothing()表示padding获得的值有可能是自动测试框架随机生成的(对应random(1, 3)),也有可能是Pytorch或者OneFlow框架给出的参数默认值(对应nothing)。
random(1, 3) | nothing()
random(1, 3)
nothing
@autotest() def test_flow_matmul_with_random_data(test_case): k = random(1, 6) x = random_pytorch_tensor(ndim=2, dim1=k) y = random_pytorch_tensor(ndim=2, dim0=k) z = torch.matmul(x, y) return z
对于flow.matmul的测试,我们基于random_pytorch_tensor方法构造了两个随机Tensor x和y,它们的维度分别是[m, k]和[k, n],这些维度的值都是随机生成的。
random_pytorch_tensor
x
y
[m, k]
[k, n]
@autotest() def test_tensor_tan(test_case): x = random_pytorch_tensor().to(random_device()) return x.tan()
和flow.xxx方法的测试类似,这里就不重复了。
flow.xxx
新版的自动测试框架最大的好处就是我们可以像写Pytorch代码那样去写一个测试样例,并且我们可以随意组合参数输入数据的genator来产生各种不同的数据。另外@autotest()这个装饰器还有一些可选参数让我们灵活控制测试,这里看一下定义def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5)。例如对于greater这个函数,它不支持反向,那么我们就设置auto_backward=False,如果某个Module误差会比1e-4还大,那么我们可以调整rtol,再比如我们想控制随机测试的有效次数,则可以通过设置n来解决。
@autotest()
def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5)
greater
auto_backward=False
rtol
n
了解了使用方法之后,这一节我们来了解一下新版自动测试框架的实现思路。我主要从两个方面来讲,首先是讲解自动测试框架中的generators,然后讲解自动测试框架中和Pytorch交互部分的核心实现。
generators
实现在:oneflow/python/test/modules/automated_test_util/generators.py。首先我们从最后几行可以了解到这个文件主要导出了一些generator方法:
oneflow/python/test/modules/automated_test_util/generators.py
generator
__all__ = [ "random_tensor", "random_bool", "random_device", "random", "random_or_nothing", "constant", "nothing", "test_module_against_pytorch", "test_flow_against_pytorch", "test_tensor_against_pytorch", ]
在这个文件中,从def test_against_pytorch这行代码开始到结束的代码在新版的测试框架中已经没用了,只是为了兼容前一个版本大家写的自动测试方法的PR,在大家把之前的自动测试方法迁移到新版后会移除这些代码。这样的话,这个文件就剩下了一些generator的实现。random,random_bool ,constant所有这些generator都继承了generator基类,基类的定义如下:
def test_against_pytorch
random
random_bool
constant
class generator: def __init__(self, children): self.children = children self._value = None def _init(self): self._value = None for x in self.children: x._init() def eval(self): self._init() return self.value() def _calc_value(self): raise NotImplementedError() def value(self): if self._value is None: self._value = self._calc_value() return self._value def size(self): return 1 def __or__(self, other): other = pack(other) return oneof( self, other, possibility=self.size() / (self.size() + other.size()) ) def __ror__(self, other): return self | other def __add__(self, other): return add(self, other) def __radd__(self, other): return self + other def __sub__(self, other): return self + neg(other) def __rsub__(self, other): return neg(self - other) def to(self, annotation): self._to(annotation) for x in self.children: x.to(annotation) return self def _to(self, annotation): pass
所有的generator子类都继承了这个基类,并重写其中的__init__和__calc_value,size等成员函数。例如Nothing就是直接在_calc_value里面返回一个空的class, 实现如下:
__init__
__calc_value
size
Nothing
_calc_value
class
class Nothing: pass class nothing(generator): def __init__(self): super().__init__([]) def _calc_value(self): return Nothing()
再例如,random这个子类定义如下:
class random(generator): def __init__(self, low=1, high=6): self.low = pack(low) self.high = pack(high) super().__init__([self.low, self.high]) self.annotation = None def _to(self, annotation): if self.annotation is not None: return if hasattr(annotation, "__origin__"): # PyTorch _size_2_t and similar types are defined by type variables, # leading to unexpected __args__ and __origin__ # # >>> _size_2_t = Union[T, Tuple[T, T]][int] # >>> _size_2_t.__origin__ # typing.Union[~T, typing.Tuple[~T, ~T]] # # So recreate a new annotation object by repr and eval # # >>> _size_2_t # typing.Union[int, typing.Tuple[int, int]] # >>> _size_2_t_new = eval(repr(annotation)) # >>> _size_2_t_new.__origin__ # typing.Union annotation = eval(repr(annotation)) self.annotation = annotation def _generate(self, annotation): if hasattr(annotation, "__origin__"): if annotation.__origin__ is Union: x = random_util.choice(annotation.__args__) return self._generate(x) if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple: return [self._generate(x) for x in annotation.__args__] else: raise NotImplementedError( f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}" ) low, high = self.low.value(), self.high.value() if annotation == int: val = int(rng.integers(low, high)) elif annotation == float: val = float(rng.random() * (high - low) + low) elif annotation == bool: val = random_util.choice([True, False]) else: raise NotImplementedError( f"Not implemented annotation {annotation} in random" ) return val def _calc_value(self): return self._generate(self.annotation) def random_or_nothing(low, high): return oneof(random(low, high), nothing(), possibility=2 / 3)
这里需要注意的是annoation是在下一个阶段传进来的参数的Type,比如int, float, bool,__size_2_t等等。获取了annoation的值这个generator才可以产生真正合法的数据。
annoation
int
float
bool
__size_2_t
这是新版测试框架的核心实现,代码在:oneflow/python/test/modules/automated_test_util/torch_flow_dual_object.py 。我们依然从这个文件导出了什么开始:__all__ = ["torch", "autotest", "random_pytorch_tensor"]
oneflow/python/test/modules/automated_test_util/torch_flow_dual_object.py
__all__ = ["torch", "autotest", "random_pytorch_tensor"]
首先导出了torch,这个torch可以理解为是原始Pytorch的更高层封装,这个封装体现在参数的输入数据可以用上一节的generators来进行组合。导出high level 的torch的代码如下:torch = GetDualObject("", torch_original, flow)。
torch
torch = GetDualObject("", torch_original, flow)
所以核心实现是GetDualObject这个函数,我们来看一下这个函数做了什么?
GetDualObject
class DualObject: def __init__(self, name, pytorch, oneflow): self.name = name self.pytorch = pytorch self.oneflow = oneflow if isinstance(pytorch, torch_original.nn.Module): state_dict = pytorch.state_dict() state_dict = {k: v.detach().cpu().numpy() for k, v in state_dict.items()} oneflow.load_state_dict(state_dict) dual_modules_to_test.append(self) if isinstance(pytorch, torch_original.Tensor): dual_objects_to_test.append(self) def __repr__(self): return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}" def __getattr__(self, key): pytorch_attr = getattr(self.pytorch, key) oneflow_attr = getattr(self.oneflow, key) new_name = f"{self.name}.{key}" return GetDualObject(new_name, pytorch_attr, oneflow_attr)
在初始化里面首先传了两个Python对象,分别是Pytorch和OneFlow,在导出high level的torch的时候传的是:torch_original和flow,而导出random_pytorch_tensor的时候传的是:pytorch_tensor和flow_tensor。这里不妨先列出random_pytorch_tensor这个函数的实现:
torch_original
flow
pytorch_tensor
flow_tensor
def random_pytorch_tensor( ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None, requires_grad=True ): if isinstance(requires_grad, generator): requires_grad = requires_grad.value() pytorch_tensor = ( random_tensor(ndim, dim0, dim1, dim2, dim3, dim4) .value() .requires_grad_(requires_grad) ) flow_tensor = flow.tensor(pytorch_tensor.detach().cpu().numpy(), requires_grad=True) return GetDualObject("unused", pytorch_tensor, flow_tensor)
可以看到它和导出high level的torch是一样的实现思路, 仍然调用了GetDualObject这个类的构造函数。
继续回到DualObject的实现,我们发现这里分别使用了dual_modules_to_test和dual_objects_to_test这两个列表分别来记录OneFlow和Pytorch的Module和Tensor对象。并重写了__getattr__魔法方法,以Flatten为例子查看一下它具体做了什么?
DualObject
dual_modules_to_test
dual_objects_to_test
__getattr__
def __getattr__(self, key): pytorch_attr = getattr(self.pytorch, key) oneflow_attr = getattr(self.oneflow, key) print(key) # print(pytorch_attr) # print(oneflow_attr) new_name = f"{self.name}.{key}" return GetDualObject(new_name, pytorch_attr, oneflow_attr) # flatten的测试程序 @autotest(auto_backward=False) def test_against_pytorch(test_case): m = torch.nn.Flatten( start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing() ) m.train(random()) device = random_device() m.to(device) x = random_pytorch_tensor().to(device) y = m(x) return y
然后看一下__getattr__中key的打印结果,重复打印和省略号是因为有20轮测试:
key
nn Flatten train to to nn Flatten train to to nn Flatten train to ...
可以看到由@autotest这个装饰器包起来的程序的Pytorch Module或者函数都重写了这个方法,它将这些Module或者函数的参数取出来用GetDualObject返回一个新的DualObject对象。我们可以打印一下Flatten这个Module对应的DualObject对象是什么?
@autotest
PyTorch object: <bound method Module.train of Flatten(start_dim=1, end_dim=-1)> OneFlow object: <bound method Module.train of Flatten(start_dim=1, end_dim=-1)>
GetDualObject这个函数就是根据传入的Pytorch以及OneFlow对象和它们的名字(两个类的名字必须是相同的,这样才是和Pytorch对齐)来生成一个DualObject对象。GetDualObject这个函数会为Pytorch重写传入的Pytorch以及OneFlow对象的魔法函数,返回一个DualObject对象,这个过程还包含了跳过一些不合法的魔法函数以及检查传入对象的属性是否合法。这里还有一句对于Tensor方法的特判,因为Tensor的API调用方式和其它Module和函数不同。
接下来,就是看一下autotest装饰器的实现了:
def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5): verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None def deco(f): @functools.wraps(f) def new_f(test_case): nonlocal n while n > 0: dual_modules_to_test.clear() dual_objects_to_test.clear() try: res = f(test_case) except PyTorchDoesNotSupportError as e: if verbose: print(e) continue # TODO: support types other than Tensor, like torch.Size/flow.Size if res is not None: if not isinstance(res, collections.abc.Sequence): res = [res] for x in res: if auto_backward: if isinstance(x.pytorch, torch_original.Tensor): x.sum().backward() dual_objects_to_test.append(x) for x in dual_modules_to_test: # x.state_dict().values() returns dual object with inconsistent values for key in x.pytorch.state_dict().keys(): dual_objects_to_test.append( GetDualObject( "unused", x.pytorch.state_dict()[key], x.oneflow.state_dict()[key], ) ) for x in dual_objects_to_test: test_case.assertTrue(check_equality(x)) if verbose: print("test passed") n -= 1 return new_f return deco
最后,这个装饰器把包起来的high level的程序执行一遍,获得每个中间Tensor,最后再对每个Tensor进行判断是否在合法的精度范围内。
所以核心的实现是上面产生DualObject的过程,它完成了high-level的Pytorch往原始的pytorch的oneflow的转换,让转换后的张量程序可以直接运行,以获取所有中间结果。
在介绍generators的时候讲到各个generator 的annoations必须获取到具体参数类型才可以产生真正的合法数字,这个过程是在GetDualObject这个函数中的get_args函数完成的,里面的get_generator_value 会遍历所有的参数列表,将dtype传给这些列表中的generator生成真正的合法数据。
get_args
get_generator_value
dtype
def get_generator_value(x): if isinstance(x, generator): return x.value() return x
The text was updated successfully, but these errors were encountered:
No branches or pull requests
The text was updated successfully, but these errors were encountered: