Skip to content

Commit

Permalink
Update notebooks.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Apr 23, 2019
1 parent 7405b50 commit e430f7e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 39 deletions.
25 changes: 20 additions & 5 deletions notebooks/01_user_facing.ipynb
Expand Up @@ -22,13 +22,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from unumpy.xnd_backend import XndBackend\n",
"from unumpy.numpy_backend import NumpyBackend\n",
"from unumpy.pytorch_backend import TorchBackend"
"from unumpy.torch_backend import TorchBackend"
]
},
{
Expand All @@ -41,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -109,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -118,12 +118,27 @@
"text": [
"Using the TorchBackend with coerce=True on a NumPy array: <class 'torch.Tensor'>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/hameerabbasi/Quansight/uarray/unumpy/torch_backend.py:80: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" ret = torch.tensor(a, dtype=dtype)\n"
]
}
],
"source": [
"with ua.set_backend(TorchBackend, coerce=True):\n",
" print('Using the TorchBackend with coerce=True on a NumPy array: {}'.format(type(np.sum(z))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
52 changes: 18 additions & 34 deletions notebooks/02_basic_dev_tutorial.ipynb
Expand Up @@ -17,7 +17,8 @@
"metadata": {},
"outputs": [],
"source": [
"from uarray import argument_extractor"
"from uarray import create_multimethod, all_of_type\n",
"import unumpy as unp"
]
},
{
Expand All @@ -30,7 +31,8 @@
" out_args = arrays + args[2:]\n",
" return out_args, kwargs\n",
"\n",
"@argument_extractor(solve_argreplacer)\n",
"@create_multimethod(solve_argreplacer)\n",
"@all_of_type(unp.ndarray)\n",
"def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False, debug=None, check_finite=True, assume_a='gen', transposed=False):\n",
" return (a, b)"
]
Expand All @@ -49,8 +51,8 @@
"metadata": {},
"outputs": [],
"source": [
"from uarray import multimethod\n",
"from unumpy.numpy_backend import NumpyBackend\n",
"from uarray import register_implementation\n",
"from unumpy.numpy_backend import NumpyBackend, compat_check as np_cc\n",
"import scipy.linalg as linalg"
]
},
Expand All @@ -71,7 +73,7 @@
}
],
"source": [
"multimethod(NumpyBackend, solve)(linalg.solve)"
"register_implementation(solve, NumpyBackend, compat_check=np_cc)(linalg.solve)"
]
},
{
Expand All @@ -88,8 +90,8 @@
"metadata": {},
"outputs": [],
"source": [
"from uarray import multimethod\n",
"from unumpy.pytorch_backend import TorchBackend\n",
"from uarray import register_implementation\n",
"from unumpy.torch_backend import TorchBackend, compat_check as torch_cc\n",
"import torch"
]
},
Expand Down Expand Up @@ -120,32 +122,7 @@
}
],
"source": [
"multimethod(TorchBackend, solve)(solve_impl)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Backend Code\n",
"## For NumPy\n",
"```python\n",
"from uarray.backend import TypeCheckBackend, register_backend\n",
"import numpy as np\n",
"\n",
"NumpyBackend = TypeCheckBackend((np.ndarray, np.generic), convertor=np.array,\n",
" fallback_types=(tuple, list, int, float, bool))\n",
"register_backend(NumpyBackend)\n",
"```\n",
"\n",
"## For PyTorch\n",
"```python\n",
"import torch\n",
"from uarray.backend import TypeCheckBackend, register_backend\n",
"\n",
"TorchBackend = TypeCheckBackend((torch.Tensor,), convertor=torch.Tensor)\n",
"register_backend(TorchBackend)\n",
"```"
"register_implementation(solve, TorchBackend, compat_check=torch_cc)(solve_impl)"
]
},
{
Expand Down Expand Up @@ -256,12 +233,19 @@
],
"source": [
"import uarray as ua\n",
"with ua.set_backend(NumpyBackend, coerce=None):\n",
"with ua.set_backend(NumpyBackend, coerce=True):\n",
" print(type(solve(a, b)))\n",
" \n",
"with ua.set_backend(TorchBackend, coerce=True):\n",
" print(type(solve(a, b)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit e430f7e

Please sign in to comment.