In [1]:
import torch
from cliva_fl.models.relu_mlp import ReLuMLP

In [2]:
l = [784, 1024, 1024, 1024, 10]
A = ReLuMLP(l)
B = ReLuMLP(l)

In [3]:
def tensors_close_custom(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-07, atol=1e-06) -> bool:
    return torch.lt(torch.abs(torch.sub(tensor1, tensor2)), atol + rtol * torch.abs(tensor2)).all().item()

In [4]:
%%timeit
for m1, m2 in zip(A.layers, B.layers):
    if type(m1) in [torch.nn.Linear]:
        tensors_close_custom(m2.weight, m1.weight)
        tensors_close_custom(m2.bias, m1.bias)

6.17 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
def tensors_close_torch(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-07, atol=1e-06) -> bool:
    return torch.allclose(tensor1, tensor2, rtol, atol)

In [6]:
%%timeit
for m1, m2 in zip(A.layers, B.layers):
    if type(m1) in [torch.nn.Linear]:
        tensors_close_torch(m2.weight, m1.weight)
        tensors_close_torch(m2.bias, m1.bias)

10.8 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def tensors_close_rand(tensor1: torch.Tensor, tensor2: torch.Tensor, n_check=1, rtol=1e-07, atol=1e-06) -> bool:
    R = torch.round(torch.rand(tensor1.shape[1], n_check))
    tp1 = torch.matmul(tensor1, R)
    tp2 = torch.matmul(tensor2, R)
    return torch.allclose(tp1, tp2, rtol, atol)

In [8]:
%%timeit
for m1, m2 in zip(A.layers, B.layers):
    if type(m1) in [torch.nn.Linear]:
        tensors_close_rand(m2.weight, m1.weight, n_check=9)
        tensors_close_custom(m2.bias, m1.bias)

1.41 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
def tensors_close_sum(tensor1: torch.Tensor, tensor2: torch.Tensor, n_check=1, rtol=1e-07, atol=1e-06) -> bool:
    res = True
    for i in range(len(tensor1.shape)):
        res &= tensors_close_custom(tensor1.sum(i), tensor2.sum(i), rtol, atol)
    return res

```
x 0 x 0
0 0 0 0
x 0 x 0
```

In [10]:
%%timeit
for m1, m2 in zip(A.layers, B.layers):
    if type(m1) in [torch.nn.Linear]:
        tensors_close_sum(m2.weight, m1.weight)
        tensors_close_custom(m2.bias, m1.bias)

1.57 ms ± 92.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
for _ in range(100):
    l = [784, 1024, 1024, 1024, 10]
    A = ReLuMLP(l)
    B = ReLuMLP(l)
    for m1, m2 in zip(A.layers, B.layers):
        if type(m1) in [torch.nn.Linear]:
            assert tensors_close_sum(m2.weight, m1.weight) == tensors_close_torch(m2.weight, m1.weight)