diff --git a/adaptive/_version.py b/adaptive/_version.py
index caefe5705..8d044cc72 100644
--- a/adaptive/_version.py
+++ b/adaptive/_version.py
@@ -4,6 +4,7 @@
 import subprocess
 from collections import namedtuple
 from distutils.command.build_py import build_py as build_py_orig
+from typing import Dict
 
 from setuptools.command.sdist import sdist as sdist_orig
 
@@ -19,7 +20,7 @@
 STATIC_VERSION_FILE = "_static_version.py"
 
 
-def get_version(version_file=STATIC_VERSION_FILE):
+def get_version(version_file: str = STATIC_VERSION_FILE) -> str:
     version_info = get_static_version_info(version_file)
     version = version_info["version"]
     if version == "__use_git__":
@@ -33,7 +34,7 @@ def get_version(version_file=STATIC_VERSION_FILE):
         return version
 
 
-def get_static_version_info(version_file=STATIC_VERSION_FILE):
+def get_static_version_info(version_file: str = STATIC_VERSION_FILE) -> Dict[str, str]:
     version_info = {}
     with open(os.path.join(package_root, version_file), "rb") as f:
         exec(f.read(), {}, version_info)
@@ -44,7 +45,7 @@ def version_is_from_git(version_file=STATIC_VERSION_FILE):
     return get_static_version_info(version_file)["version"] == "__use_git__"
 
 
-def pep440_format(version_info):
+def pep440_format(version_info: Version) -> str:
     release, dev, labels = version_info
 
     version_parts = [release]
@@ -61,7 +62,7 @@ def pep440_format(version_info):
     return "".join(version_parts)
 
 
-def get_version_from_git():
+def get_version_from_git() -> Version:
     try:
         p = subprocess.Popen(
             ["git", "rev-parse", "--show-toplevel"],
diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py
index 39a454485..e13697988 100644
--- a/adaptive/learner/average_learner.py
+++ b/adaptive/learner/average_learner.py
@@ -1,4 +1,5 @@
 from math import sqrt
+from typing import Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 
@@ -30,7 +31,12 @@ class AverageLearner(BaseLearner):
         Number of evaluated points.
     """
 
-    def __init__(self, function, atol=None, rtol=None):
+    def __init__(
+        self,
+        function: Callable,
+        atol: Optional[float] = None,
+        rtol: Optional[float] = None,
+    ) -> None:
         if atol is None and rtol is None:
             raise Exception("At least one of `atol` and `rtol` should be set.")
         if atol is None:
@@ -48,10 +54,10 @@ def __init__(self, function, atol=None, rtol=None):
         self.sum_f_sq = 0
 
     @property
-    def n_requested(self):
+    def n_requested(self) -> int:
         return self.npoints + len(self.pending_points)
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[float]]:
         points = list(range(self.n_requested, self.n_requested + n))
 
         if any(p in self.data or p in self.pending_points for p in points):
@@ -68,7 +74,7 @@ def ask(self, n, tell_pending=True):
                 self.tell_pending(p)
         return points, loss_improvements
 
-    def tell(self, n, value):
+    def tell(self, n: int, value: float) -> None:
         if n in self.data:
             # The point has already been added before.
             return
@@ -79,16 +85,16 @@ def tell(self, n, value):
         self.sum_f_sq += value ** 2
         self.npoints += 1
 
-    def tell_pending(self, n):
+    def tell_pending(self, n: int) -> None:
         self.pending_points.add(n)
 
     @property
-    def mean(self):
+    def mean(self) -> float:
         """The average of all values in `data`."""
         return self.sum_f / self.npoints
 
     @property
-    def std(self):
+    def std(self) -> float:
         """The corrected sample standard deviation of the values
         in `data`."""
         n = self.npoints
@@ -101,7 +107,7 @@ def std(self):
         return sqrt(numerator / (n - 1))
 
     @cache_latest
-    def loss(self, real=True, *, n=None):
+    def loss(self, real: bool = True, *, n=None) -> float:
         if n is None:
             n = self.npoints if real else self.n_requested
         else:
@@ -113,7 +119,7 @@ def loss(self, real=True, *, n=None):
             standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
         )
 
-    def _loss_improvement(self, n):
+    def _loss_improvement(self, n: int) -> float:
         loss = self.loss()
         if np.isfinite(loss):
             return loss - self.loss(n=self.npoints + n)
@@ -139,8 +145,8 @@ def plot(self):
         vals = hv.Points(vals)
         return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)
 
-    def _get_data(self):
+    def _get_data(self) -> Tuple[Dict[int, float], int, float, float]:
         return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
 
-    def _set_data(self, data):
+    def _set_data(self, data: Tuple[Dict[int, float], int, float, float]) -> None:
         self.data, self.npoints, self.sum_f, self.sum_f_sq = data
diff --git a/adaptive/learner/balancing_learner.py b/adaptive/learner/balancing_learner.py
index ce40afb22..9707ffbf4 100644
--- a/adaptive/learner/balancing_learner.py
+++ b/adaptive/learner/balancing_learner.py
@@ -4,6 +4,7 @@
 from contextlib import suppress
 from functools import partial
 from operator import itemgetter
+from typing import Any, Callable, Dict, List, Set, Tuple, Union
 
 import numpy as np
 
@@ -12,7 +13,7 @@
 from adaptive.utils import cache_latest, named_product, restore
 
 
-def dispatch(child_functions, arg):
+def dispatch(child_functions: List[Callable], arg: Any,) -> Union[Any]:
     index, x = arg
     return child_functions[index](x)
 
@@ -68,7 +69,9 @@ class BalancingLearner(BaseLearner):
     behave in an undefined way. Change the `strategy` in that case.
     """
 
-    def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
+    def __init__(
+        self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements"
+    ) -> None:
         self.learners = learners
 
         # Naively we would make 'function' a method, but this causes problems
@@ -89,21 +92,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
         self.strategy = strategy
 
     @property
-    def data(self):
+    def data(self) -> Dict[Tuple[int, Any], Any]:
         data = {}
         for i, l in enumerate(self.learners):
             data.update({(i, p): v for p, v in l.data.items()})
         return data
 
     @property
-    def pending_points(self):
+    def pending_points(self) -> Set[Tuple[int, Any]]:
         pending_points = set()
         for i, l in enumerate(self.learners):
             pending_points.update({(i, p) for p in l.pending_points})
         return pending_points
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         return sum(l.npoints for l in self.learners)
 
     @property
@@ -135,7 +138,9 @@ def strategy(self, strategy):
                 ' strategy="npoints", or strategy="cycle" is implemented.'
             )
 
-    def _ask_and_tell_based_on_loss_improvements(self, n):
+    def _ask_and_tell_based_on_loss_improvements(
+        self, n: int
+    ) -> Tuple[List[Tuple[int, Any]], List[float]]:
         selected = []  # tuples ((learner_index, point), loss_improvement)
         total_points = [l.npoints + len(l.pending_points) for l in self.learners]
         for _ in range(n):
@@ -158,7 +163,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
         points, loss_improvements = map(list, zip(*selected))
         return points, loss_improvements
 
-    def _ask_and_tell_based_on_loss(self, n):
+    def _ask_and_tell_based_on_loss(
+        self, n: int
+    ) -> Tuple[List[Tuple[int, Any]], List[float]]:
         selected = []  # tuples ((learner_index, point), loss_improvement)
         total_points = [l.npoints + len(l.pending_points) for l in self.learners]
         for _ in range(n):
@@ -179,7 +186,9 @@ def _ask_and_tell_based_on_loss(self, n):
         points, loss_improvements = map(list, zip(*selected))
         return points, loss_improvements
 
-    def _ask_and_tell_based_on_npoints(self, n):
+    def _ask_and_tell_based_on_npoints(
+        self, n: int
+    ) -> Tuple[List[Tuple[int, Any]], List[float]]:
         selected = []  # tuples ((learner_index, point), loss_improvement)
         total_points = [l.npoints + len(l.pending_points) for l in self.learners]
         for _ in range(n):
@@ -195,7 +204,9 @@ def _ask_and_tell_based_on_npoints(self, n):
         points, loss_improvements = map(list, zip(*selected))
         return points, loss_improvements
 
-    def _ask_and_tell_based_on_cycle(self, n):
+    def _ask_and_tell_based_on_cycle(
+        self, n: int
+    ) -> Tuple[List[Tuple[int, Any]], List[float]]:
         points, loss_improvements = [], []
         for _ in range(n):
             index = next(self._cycle)
@@ -206,7 +217,9 @@ def _ask_and_tell_based_on_cycle(self, n):
 
         return points, loss_improvements
 
-    def ask(self, n, tell_pending=True):
+    def ask(
+        self, n: int, tell_pending: bool = True
+    ) -> Tuple[List[Tuple[int, Any]], List[float]]:
         """Chose points for learners."""
         if n == 0:
             return [], []
@@ -217,20 +230,20 @@ def ask(self, n, tell_pending=True):
         else:
             return self._ask_and_tell(n)
 
-    def tell(self, x, y):
+    def tell(self, x: Tuple[int, Any], y: Any,) -> None:
         index, x = x
         self._ask_cache.pop(index, None)
         self._loss.pop(index, None)
         self._pending_loss.pop(index, None)
         self.learners[index].tell(x, y)
 
-    def tell_pending(self, x):
+    def tell_pending(self, x: Tuple[int, Any]) -> None:
         index, x = x
         self._ask_cache.pop(index, None)
         self._loss.pop(index, None)
         self.learners[index].tell_pending(x)
 
-    def _losses(self, real=True):
+    def _losses(self, real: bool = True) -> List[float]:
         losses = []
         loss_dict = self._loss if real else self._pending_loss
 
@@ -242,7 +255,7 @@ def _losses(self, real=True):
         return losses
 
     @cache_latest
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> Union[float]:
         losses = self._losses(real)
         return max(losses)
 
@@ -325,7 +338,9 @@ def remove_unfinished(self):
             learner.remove_unfinished()
 
     @classmethod
-    def from_product(cls, f, learner_type, learner_kwargs, combos):
+    def from_product(
+        cls, f, learner_type, learner_kwargs, combos
+    ) -> "BalancingLearner":
         """Create a `BalancingLearner` with learners of all combinations of
         named variables’ values. The `cdims` will be set correctly, so calling
         `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -372,7 +387,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
             learners.append(learner)
         return cls(learners, cdims=arguments)
 
-    def save(self, fname, compress=True):
+    def save(self, fname: Callable, compress: bool = True) -> None:
         """Save the data of the child learners into pickle files
         in a directory.
 
@@ -410,7 +425,7 @@ def save(self, fname, compress=True):
             for l in self.learners:
                 l.save(fname(l), compress=compress)
 
-    def load(self, fname, compress=True):
+    def load(self, fname: Callable, compress: bool = True) -> None:
         """Load the data of the child learners from pickle files
         in a directory.
 
diff --git a/adaptive/learner/base_learner.py b/adaptive/learner/base_learner.py
index f7e3212c9..34c65a167 100644
--- a/adaptive/learner/base_learner.py
+++ b/adaptive/learner/base_learner.py
@@ -1,11 +1,12 @@
 import abc
 from contextlib import suppress
 from copy import deepcopy
+from typing import Any, Callable, Dict
 
 from adaptive.utils import _RequireAttrsABCMeta, load, save
 
 
-def uses_nth_neighbors(n):
+def uses_nth_neighbors(n: int) -> Callable:
     """Decorator to specify how many neighboring intervals the loss function uses.
 
     Wraps loss functions to indicate that they expect intervals together
@@ -84,7 +85,7 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
     npoints: int
     pending_points: set
 
-    def tell(self, x, y):
+    def tell(self, x: Any, y) -> None:
         """Tell the learner about a single value.
 
         Parameters
@@ -94,7 +95,7 @@ def tell(self, x, y):
         """
         self.tell_many([x], [y])
 
-    def tell_many(self, xs, ys):
+    def tell_many(self, xs: Any, ys: Any) -> None:
         """Tell the learner about some values.
 
         Parameters
@@ -161,7 +162,7 @@ def copy_from(self, other):
         """
         self._set_data(other._get_data())
 
-    def save(self, fname, compress=True):
+    def save(self, fname: str, compress: bool = True) -> None:
         """Save the data of the learner into a pickle file.
 
         Parameters
@@ -175,7 +176,7 @@ def save(self, fname, compress=True):
         data = self._get_data()
         save(fname, data, compress)
 
-    def load(self, fname, compress=True):
+    def load(self, fname: str, compress: bool = True) -> None:
         """Load the data of a learner from a pickle file.
 
         Parameters
@@ -190,8 +191,8 @@ def load(self, fname, compress=True):
             data = load(fname, compress)
             self._set_data(data)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return deepcopy(self.__dict__)
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__ = state
diff --git a/adaptive/learner/data_saver.py b/adaptive/learner/data_saver.py
index 14e246184..2455154e9 100644
--- a/adaptive/learner/data_saver.py
+++ b/adaptive/learner/data_saver.py
@@ -1,5 +1,7 @@
 import functools
 from collections import OrderedDict
+from operator import itemgetter
+from typing import Any, Dict, Tuple, Union
 
 from adaptive.learner.base_learner import BaseLearner
 from adaptive.utils import copy_docstring_from
@@ -25,13 +27,13 @@ class DataSaver:
     >>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
     """
 
-    def __init__(self, learner, arg_picker):
+    def __init__(self, learner: BaseLearner, arg_picker: itemgetter) -> None:
         self.learner = learner
         self.extra_data = OrderedDict()
         self.function = learner.function
         self.arg_picker = arg_picker
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: str) -> Any:
         return getattr(self.learner, attr)
 
     @copy_docstring_from(BaseLearner.tell)
@@ -44,10 +46,17 @@ def tell(self, x, result):
     def tell_pending(self, x):
         self.learner.tell_pending(x)
 
-    def _get_data(self):
+    def _get_data(self,) -> Tuple[Any, OrderedDict]:
         return self.learner._get_data(), self.extra_data
 
-    def _set_data(self, data):
+    def _set_data(
+        self,
+        data: Union[
+            Tuple[OrderedDict, OrderedDict],
+            Tuple[Dict[Union[int, float], float], OrderedDict],
+            Tuple[Tuple[Dict[int, float], int, float, float], OrderedDict],
+        ],
+    ) -> None:
         learner_data, self.extra_data = data
         self.learner._set_data(learner_data)
 
diff --git a/adaptive/learner/integrator_coeffs.py b/adaptive/learner/integrator_coeffs.py
index 7719f601f..8d1c24fbc 100644
--- a/adaptive/learner/integrator_coeffs.py
+++ b/adaptive/learner/integrator_coeffs.py
@@ -2,12 +2,13 @@
 
 from collections import defaultdict
 from fractions import Fraction
+from typing import List, Tuple
 
 import numpy as np
 import scipy.linalg
 
 
-def legendre(n):
+def legendre(n: int) -> List[List[Fraction]]:
     """Return the first n Legendre polynomials.
 
     The polynomials have *standard* normalization, i.e.
@@ -28,7 +29,7 @@ def legendre(n):
     return result
 
 
-def newton(n):
+def newton(n: int) -> np.ndarray:
     """Compute the monomial coefficients of the Newton polynomial over the
     nodes of the n-point Clenshaw-Curtis quadrature rule.
     """
@@ -85,7 +86,7 @@ def newton(n):
     return cf
 
 
-def scalar_product(a, b):
+def scalar_product(a: List[Fraction], b: List[Fraction]) -> Fraction:
     """Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
 
     The args must be sequences of polynomial coefficients.  This
@@ -106,7 +107,7 @@ def scalar_product(a, b):
     return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
 
 
-def calc_bdef(ns):
+def calc_bdef(ns: Tuple[int, int, int, int]) -> List[np.ndarray]:
     """Calculate the decompositions of Newton polynomials (over the nodes
     of the n-point Clenshaw-Curtis quadrature rule) in terms of
     Legandre polynomials.
@@ -132,7 +133,7 @@ def calc_bdef(ns):
     return result
 
 
-def calc_V(x, n):
+def calc_V(x: np.ndarray, n: int) -> np.ndarray:
     V = [np.ones(x.shape), x.copy()]
     for i in range(2, n):
         V.append((2 * i - 1) / i * x * V[-1] - (i - 1) / i * V[-2])
diff --git a/adaptive/learner/integrator_learner.py b/adaptive/learner/integrator_learner.py
index 9c0aeb008..6cb595e6f 100644
--- a/adaptive/learner/integrator_learner.py
+++ b/adaptive/learner/integrator_learner.py
@@ -4,6 +4,7 @@
 from collections import defaultdict
 from math import sqrt
 from operator import attrgetter
+from typing import Callable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 from scipy.linalg import norm
@@ -30,7 +31,7 @@
 )
 
 
-def _downdate(c, nans, depth):
+def _downdate(c: np.ndarray, nans: List[int], depth: int) -> np.ndarray:
     # This is algorithm 5 from the thesis of Pedro Gonnet.
     b = b_def[depth].copy()
     m = ns[depth] - 1
@@ -48,7 +49,7 @@ def _downdate(c, nans, depth):
     return c
 
 
-def _zero_nans(fx):
+def _zero_nans(fx: np.ndarray) -> List[int]:
     """Caution: this function modifies fx."""
     nans = []
     for i in range(len(fx)):
@@ -58,7 +59,7 @@ def _zero_nans(fx):
     return nans
 
 
-def _calc_coeffs(fx, depth):
+def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
     """Caution: this function modifies fx."""
     nans = _zero_nans(fx)
     c_new = V_inv[depth] @ fx
@@ -138,7 +139,9 @@ class _Interval:
         "removed",
     ]
 
-    def __init__(self, a, b, depth, rdepth):
+    def __init__(
+        self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int,
+    ) -> None:
         self.children = []
         self.data = {}
         self.a = a
@@ -150,7 +153,7 @@ def __init__(self, a, b, depth, rdepth):
         self.removed = False
 
     @classmethod
-    def make_first(cls, a, b, depth=2):
+    def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
         ival = _Interval(a, b, depth, rdepth=1)
         ival.ndiv = 0
         ival.parent = None
@@ -158,7 +161,7 @@ def make_first(cls, a, b, depth=2):
         return ival
 
     @property
-    def T(self):
+    def T(self) -> np.ndarray:
         """Get the correct shift matrix.
 
         Should only be called on children of a split interval.
@@ -169,24 +172,24 @@ def T(self):
         assert left != right
         return T_left if left else T_right
 
-    def refinement_complete(self, depth):
+    def refinement_complete(self, depth: int) -> bool:
         """The interval has all the y-values to calculate the intergral."""
         if len(self.data) < ns[depth]:
             return False
         return all(p in self.data for p in self.points(depth))
 
-    def points(self, depth=None):
+    def points(self, depth: Optional[int] = None) -> np.ndarray:
         if depth is None:
             depth = self.depth
         a = self.a
         b = self.b
         return (a + b) / 2 + (b - a) * xi[depth] / 2
 
-    def refine(self):
+    def refine(self) -> "_Interval":
         self.depth += 1
         return self
 
-    def split(self):
+    def split(self) -> List["_Interval"]:
         points = self.points()
         m = points[len(points) // 2]
         ivals = [
@@ -201,10 +204,10 @@ def split(self):
 
         return ivals
 
-    def calc_igral(self):
+    def calc_igral(self) -> None:
         self.igral = (self.b - self.a) * self.c[0] / sqrt(2)
 
-    def update_heuristic_err(self, value):
+    def update_heuristic_err(self, value: float) -> None:
         """Sets the error of an interval using a heuristic (half the error of
         the parent) when the actual error cannot be calculated due to its
         parents not being finished yet. This error is propagated down to its
@@ -217,7 +220,7 @@ def update_heuristic_err(self, value):
                 continue
             child.update_heuristic_err(value / 2)
 
-    def calc_err(self, c_old):
+    def calc_err(self, c_old: np.ndarray) -> float:
         c_new = self.c
         c_diff = np.zeros(max(len(c_old), len(c_new)))
         c_diff[: len(c_old)] = c_old
@@ -229,7 +232,7 @@ def calc_err(self, c_old):
                 child.update_heuristic_err(self.err / 2)
         return c_diff
 
-    def calc_ndiv(self):
+    def calc_ndiv(self) -> None:
         div = self.parent.c00 and self.c00 / self.parent.c00 > 2
         self.ndiv += div
 
@@ -240,7 +243,7 @@ def calc_ndiv(self):
             for child in self.children:
                 child.update_ndiv_recursively()
 
-    def update_ndiv_recursively(self):
+    def update_ndiv_recursively(self) -> None:
         self.ndiv += 1
         if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
             raise DivergentIntegralError
@@ -248,7 +251,9 @@ def update_ndiv_recursively(self):
         for child in self.children:
             child.update_ndiv_recursively()
 
-    def complete_process(self, depth):
+    def complete_process(
+        self, depth: int
+    ) -> Union[Tuple[bool, bool], Tuple[bool, np.bool_]]:
         """Calculate the integral contribution and error from this interval,
         and update the done leaves of all ancestor intervals."""
         assert self.depth_complete is None or self.depth_complete == depth - 1
@@ -323,7 +328,7 @@ def complete_process(self, depth):
 
         return force_split, remove
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         lst = [
             f"(a, b)=({self.a:.5f}, {self.b:.5f})",
             f"depth={self.depth}",
@@ -335,7 +340,9 @@ def __repr__(self):
 
 
 class IntegratorLearner(BaseLearner):
-    def __init__(self, function, bounds, tol):
+    def __init__(
+        self, function: Callable, bounds: Tuple[int, int], tol: float,
+    ) -> None:
         """
         Parameters
         ----------
@@ -384,10 +391,10 @@ def __init__(self, function, bounds, tol):
         self.first_ival = ival
 
     @property
-    def approximating_intervals(self):
+    def approximating_intervals(self) -> Set["_Interval"]:
         return self.first_ival.done_leaves
 
-    def tell(self, point, value):
+    def tell(self, point: float, value: float) -> None:
         if point not in self.x_mapping:
             raise ValueError(f"Point {point} doesn't belong to any interval")
         self.data[point] = value
@@ -423,7 +430,7 @@ def tell(self, point, value):
     def tell_pending(self):
         pass
 
-    def propagate_removed(self, ival):
+    def propagate_removed(self, ival: "_Interval") -> None:
         def _propagate_removed_down(ival):
             ival.removed = True
             self.ivals.discard(ival)
@@ -433,7 +440,7 @@ def _propagate_removed_down(ival):
 
         _propagate_removed_down(ival)
 
-    def add_ival(self, ival):
+    def add_ival(self, ival: "_Interval") -> None:
         for x in ival.points():
             # Update the mappings
             self.x_mapping[x].add(ival)
@@ -444,7 +451,7 @@ def add_ival(self, ival):
                 self._stack.append(x)
         self.ivals.add(ival)
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
         """Choose points for learners."""
         if not tell_pending:
             with restore(self):
@@ -452,7 +459,7 @@ def ask(self, n, tell_pending=True):
         else:
             return self._ask_and_tell_pending(n)
 
-    def _ask_and_tell_pending(self, n):
+    def _ask_and_tell_pending(self, n: int) -> Tuple[List[float], List[float]]:
         points, loss_improvements = self.pop_from_stack(n)
         n_left = n - len(points)
         while n_left > 0:
@@ -468,7 +475,7 @@ def _ask_and_tell_pending(self, n):
 
         return points, loss_improvements
 
-    def pop_from_stack(self, n):
+    def pop_from_stack(self, n: int) -> Tuple[List[float], List[float]]:
         points = self._stack[:n]
         self._stack = self._stack[n:]
         loss_improvements = [
@@ -479,7 +486,7 @@ def pop_from_stack(self, n):
     def remove_unfinished(self):
         pass
 
-    def _fill_stack(self):
+    def _fill_stack(self) -> List[float]:
         # XXX: to-do if all the ivals have err=inf, take the interval
         # with the lowest rdepth and no children.
         force_split = bool(self.priority_split)
@@ -515,16 +522,16 @@ def _fill_stack(self):
         return self._stack
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         """Number of evaluated points."""
         return len(self.data)
 
     @property
-    def igral(self):
+    def igral(self) -> float:
         return sum(i.igral for i in self.approximating_intervals)
 
     @property
-    def err(self):
+    def err(self) -> float:
         if self.approximating_intervals:
             err = sum(i.err for i in self.approximating_intervals)
             if err > sys.float_info.max:
diff --git a/adaptive/learner/learner1D.py b/adaptive/learner/learner1D.py
index e6ace878c..ebdc356f0 100644
--- a/adaptive/learner/learner1D.py
+++ b/adaptive/learner/learner1D.py
@@ -2,10 +2,11 @@
 import math
 from collections.abc import Iterable
 from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-import sortedcollections
-import sortedcontainers
+from sortedcollections.recipes import ItemSortedDict
+from sortedcontainers.sorteddict import SortedDict
 
 from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
 from adaptive.learner.learnerND import volume
@@ -15,7 +16,7 @@
 
 
 @uses_nth_neighbors(0)
-def uniform_loss(xs, ys):
+def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float]) -> float:
     """Loss function that samples the domain uniformly.
 
     Works with `~adaptive.Learner1D` only.
@@ -35,7 +36,10 @@ def uniform_loss(xs, ys):
 
 
 @uses_nth_neighbors(0)
-def default_loss(xs, ys):
+def default_loss(
+    xs: Tuple[float, float],
+    ys: Union[Tuple[np.ndarray, np.ndarray], Tuple[float, float]],
+) -> float:
     """Calculate loss on a single interval.
 
     Currently returns the rescaled length of the interval. If one of the
@@ -52,7 +56,7 @@ def default_loss(xs, ys):
 
 
 @uses_nth_neighbors(1)
-def triangle_loss(xs, ys):
+def triangle_loss(xs: Tuple[float], ys: Tuple[Union[float, np.ndarray]]) -> float:
     xs = [x for x in xs if x is not None]
     ys = [y for y in ys if y is not None]
 
@@ -69,7 +73,9 @@ def triangle_loss(xs, ys):
     return sum(vol(pts[i : i + 3]) for i in range(N)) / N
 
 
-def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+def curvature_loss_function(
+    area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
+) -> Callable:
     # XXX: add a doc-string
     @uses_nth_neighbors(1)
     def curvature_loss(xs, ys):
@@ -88,7 +94,7 @@ def curvature_loss(xs, ys):
     return curvature_loss
 
 
-def linspace(x_left, x_right, n):
+def linspace(x_left: float, x_right: float, n: int,) -> List[float]:
     """This is equivalent to
     'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
     but it is 15-30 times faster for small 'n'."""
@@ -100,17 +106,17 @@ def linspace(x_left, x_right, n):
         return [x_left + step * i for i in range(1, n)]
 
 
-def _get_neighbors_from_list(xs):
+def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
     xs = np.sort(xs)
     xs_left = np.roll(xs, 1).tolist()
     xs_right = np.roll(xs, -1).tolist()
     xs_left[0] = None
     xs_right[-1] = None
     neighbors = {x: [x_L, x_R] for x, x_L, x_R in zip(xs, xs_left, xs_right)}
-    return sortedcontainers.SortedDict(neighbors)
+    return SortedDict(neighbors)
 
 
-def _get_intervals(x, neighbors, nth_neighbors):
+def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
     nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
@@ -163,7 +169,12 @@ class Learner1D(BaseLearner):
     decorator for more information.
     """
 
-    def __init__(self, function, bounds, loss_per_interval=None):
+    def __init__(
+        self,
+        function: Callable,
+        bounds: Tuple[float, float],
+        loss_per_interval: Optional[Callable] = None,
+    ) -> None:
         self.function = function
 
         if hasattr(loss_per_interval, "nth_neighbors"):
@@ -183,8 +194,8 @@ def __init__(self, function, bounds, loss_per_interval=None):
 
         # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
         # properties.
-        self.neighbors = sortedcontainers.SortedDict()
-        self.neighbors_combined = sortedcontainers.SortedDict()
+        self.neighbors = SortedDict()
+        self.neighbors_combined = SortedDict()
 
         # Bounding box [[minx, maxx], [miny, maxy]].
         self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -205,7 +216,7 @@ def __init__(self, function, bounds, loss_per_interval=None):
         self._vdim = None
 
     @property
-    def vdim(self):
+    def vdim(self) -> int:
         """Length of the output of ``learner.function``.
         If the output is unsized (when it's a scalar)
         then `vdim = 1`.
@@ -225,35 +236,37 @@ def vdim(self):
         return self._vdim
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         """Number of evaluated points."""
         return len(self.data)
 
     @cache_latest
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> float:
         losses = self.losses if real else self.losses_combined
         if not losses:
             return np.inf
         max_interval, max_loss = losses.peekitem(0)
         return max_loss
 
-    def _scale_x(self, x):
+    def _scale_x(self, x: Optional[float]) -> Optional[float]:
         if x is None:
             return None
         return x / self._scale[0]
 
-    def _scale_y(self, y):
+    def _scale_y(
+        self, y: Optional[Union[float, np.ndarray]]
+    ) -> Optional[Union[float, np.ndarray]]:
         if y is None:
             return None
         y_scale = self._scale[1] or 1
         return y / y_scale
 
-    def _get_point_by_index(self, ind):
+    def _get_point_by_index(self, ind: int) -> Optional[float]:
         if ind < 0 or ind >= len(self.neighbors):
             return None
         return self.neighbors.keys()[ind]
 
-    def _get_loss_in_interval(self, x_left, x_right):
+    def _get_loss_in_interval(self, x_left: float, x_right: float,) -> float:
         assert x_left is not None and x_right is not None
 
         if x_right - x_left < self._dx_eps:
@@ -273,7 +286,9 @@ def _get_loss_in_interval(self, x_left, x_right):
         # we need to compute the loss for this interval
         return self.loss_per_interval(xs_scaled, ys_scaled)
 
-    def _update_interpolated_loss_in_interval(self, x_left, x_right):
+    def _update_interpolated_loss_in_interval(
+        self, x_left: float, x_right: float,
+    ) -> None:
         if x_left is None or x_right is None:
             return
 
@@ -289,7 +304,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
             self.losses_combined[a, b] = (b - a) * loss / dx
             a = b
 
-    def _update_losses(self, x, real=True):
+    def _update_losses(self, x: float, real: bool = True) -> None:
         """Update all losses that depend on x"""
         # When we add a new point x, we should update the losses
         # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -332,7 +347,7 @@ def _update_losses(self, x, real=True):
             self.losses_combined[x, b] = float("inf")
 
     @staticmethod
-    def _find_neighbors(x, neighbors):
+    def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
         if x in neighbors:
             return neighbors[x]
         pos = neighbors.bisect_left(x)
@@ -341,14 +356,14 @@ def _find_neighbors(x, neighbors):
         x_right = keys[pos] if pos != len(neighbors) else None
         return x_left, x_right
 
-    def _update_neighbors(self, x, neighbors):
+    def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
         if x not in neighbors:  # The point is new
             x_left, x_right = self._find_neighbors(x, neighbors)
             neighbors[x] = [x_left, x_right]
             neighbors.get(x_left, [None, None])[1] = x
             neighbors.get(x_right, [None, None])[0] = x
 
-    def _update_scale(self, x, y):
+    def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None:
         """Update the scale with which the x and y-values are scaled.
 
         For a learner where the function returns a single scalar the scale
@@ -375,7 +390,7 @@ def _update_scale(self, x, y):
                 self._bbox[1][1] = max(self._bbox[1][1], y)
                 self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
 
-    def tell(self, x, y):
+    def tell(self, x: float, y: Union[float, np.ndarray]) -> None:
         if x in self.data:
             # The point is already evaluated before
             return
@@ -410,7 +425,7 @@ def tell(self, x, y):
 
             self._oldscale = deepcopy(self._scale)
 
-    def tell_pending(self, x):
+    def tell_pending(self, x: float) -> None:
         if x in self.data:
             # The point is already evaluated before
             return
@@ -418,7 +433,7 @@ def tell_pending(self, x):
         self._update_neighbors(x, self.neighbors_combined)
         self._update_losses(x, real=False)
 
-    def tell_many(self, xs, ys, *, force=False):
+    def tell_many(self, xs: List[float], ys: List[Any], *, force=False) -> None:
         if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
             # Only run this more efficient method if there are
             # at least 2 points and the amount of points added are
@@ -486,7 +501,7 @@ def tell_many(self, xs, ys, *, force=False):
                 # have an inf loss.
                 self._update_interpolated_loss_in_interval(*ival)
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Any:
         """Return 'n' points that are expected to maximally reduce the loss."""
         points, loss_improvements = self._ask_points_without_adding(n)
 
@@ -496,7 +511,7 @@ def ask(self, n, tell_pending=True):
 
         return points, loss_improvements
 
-    def _ask_points_without_adding(self, n):
+    def _ask_points_without_adding(self, n: int) -> Any:
         """Return 'n' points that are expected to maximally reduce the loss.
         Without altering the state of the learner"""
         # Find out how to divide the n points over the intervals
@@ -574,7 +589,7 @@ def _ask_points_without_adding(self, n):
 
         return points, loss_improvements
 
-    def _loss(self, mapping, ival):
+    def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
         loss = mapping[ival]
         return finite_loss(ival, loss, self._scale[0])
 
@@ -613,29 +628,29 @@ def plot(self, *, scatter_or_line="scatter"):
 
         return p.redim(x=dict(range=plot_bounds))
 
-    def remove_unfinished(self):
+    def remove_unfinished(self) -> None:
         self.pending_points = set()
         self.losses_combined = deepcopy(self.losses)
         self.neighbors_combined = deepcopy(self.neighbors)
 
-    def _get_data(self):
+    def _get_data(self) -> Dict[float, float]:
         return self.data
 
-    def _set_data(self, data):
+    def _set_data(self, data: Dict[float, float]) -> None:
         if data:
             self.tell_many(*zip(*data.items()))
 
 
-def loss_manager(x_scale):
+def loss_manager(x_scale: float) -> ItemSortedDict:
     def sort_key(ival, loss):
         loss, ival = finite_loss(ival, loss, x_scale)
         return -loss, ival
 
-    sorted_dict = sortedcollections.ItemSortedDict(sort_key)
+    sorted_dict = ItemSortedDict(sort_key)
     return sorted_dict
 
 
-def finite_loss(ival, loss, x_scale):
+def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
     """Get the socalled finite_loss of an interval in order to be able to
     sort intervals that have infinite loss."""
     # If the loss is infinite we return the
diff --git a/adaptive/learner/learner2D.py b/adaptive/learner/learner2D.py
index 5e322b5b2..9730c5d55 100644
--- a/adaptive/learner/learner2D.py
+++ b/adaptive/learner/learner2D.py
@@ -3,9 +3,11 @@
 from collections import OrderedDict
 from copy import copy
 from math import sqrt
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 from scipy import interpolate
+from scipy.interpolate.interpnd import LinearNDInterpolator
 
 from adaptive.learner.base_learner import BaseLearner
 from adaptive.learner.triangulation import simplex_volume_in_embedding
@@ -15,7 +17,7 @@
 # Learner2D and helper functions.
 
 
-def deviations(ip):
+def deviations(ip: LinearNDInterpolator) -> List[np.ndarray]:
     """Returns the deviation of the linear estimate.
 
     Is useful when defining custom loss functions.
@@ -52,7 +54,7 @@ def deviation(p, v, g):
     return devs
 
 
-def areas(ip):
+def areas(ip: LinearNDInterpolator) -> np.ndarray:
     """Returns the area per triangle of the triangulation inside
     a `LinearNDInterpolator` instance.
 
@@ -73,7 +75,7 @@ def areas(ip):
     return areas
 
 
-def uniform_loss(ip):
+def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
     """Loss function that samples the domain uniformly.
 
     Works with `~adaptive.Learner2D` only.
@@ -104,7 +106,9 @@ def uniform_loss(ip):
     return np.sqrt(areas(ip))
 
 
-def resolution_loss_function(min_distance=0, max_distance=1):
+def resolution_loss_function(
+    min_distance: float = 0, max_distance: float = 1
+) -> Callable:
     """Loss function that is similar to the `default_loss` function, but you
     can set the maximimum and minimum size of a triangle.
 
@@ -143,7 +147,7 @@ def resolution_loss(ip):
     return resolution_loss
 
 
-def minimize_triangle_surface_loss(ip):
+def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:
     """Loss function that is similar to the distance loss function in the
     `~adaptive.Learner1D`. The loss is the area spanned by the 3D
     vectors of the vertices.
@@ -189,7 +193,7 @@ def _get_vectors(points):
     return np.linalg.norm(np.cross(a, b) / 2, axis=1)
 
 
-def default_loss(ip):
+def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
     """Loss function that combines `deviations` and `areas` of the triangles.
 
     Works with `~adaptive.Learner2D` only.
@@ -209,7 +213,7 @@ def default_loss(ip):
     return losses
 
 
-def choose_point_in_triangle(triangle, max_badness):
+def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray:
     """Choose a new point in inside a triangle.
 
     If the ratio of the longest edge of the triangle squared
@@ -348,7 +352,12 @@ class Learner2D(BaseLearner):
     over each triangle.
     """
 
-    def __init__(self, function, bounds, loss_per_triangle=None):
+    def __init__(
+        self,
+        function: Callable,
+        bounds: Tuple[Tuple[int, int], Tuple[int, int]],
+        loss_per_triangle: Optional[Callable] = None,
+    ) -> None:
         self.ndim = len(bounds)
         self._vdim = None
         self.loss_per_triangle = loss_per_triangle or default_loss
@@ -369,28 +378,28 @@ def __init__(self, function, bounds, loss_per_triangle=None):
         self.stack_size = 10
 
     @property
-    def xy_scale(self):
+    def xy_scale(self) -> np.ndarray:
         xy_scale = self._xy_scale
         if self.aspect_ratio == 1:
             return xy_scale
         else:
             return np.array([xy_scale[0], xy_scale[1] / self.aspect_ratio])
 
-    def _scale(self, points):
+    def _scale(self, points: Any) -> np.ndarray:
         points = np.asarray(points, dtype=float)
         return (points - self.xy_mean) / self.xy_scale
 
-    def _unscale(self, points):
+    def _unscale(self, points: np.ndarray) -> np.ndarray:
         points = np.asarray(points, dtype=float)
         return points * self.xy_scale + self.xy_mean
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         """Number of evaluated points."""
         return len(self.data)
 
     @property
-    def vdim(self):
+    def vdim(self) -> int:
         """Length of the output of ``learner.function``.
         If the output is unsized (when it's a scalar)
         then `vdim = 1`.
@@ -406,7 +415,7 @@ def vdim(self):
         return self._vdim or 1
 
     @property
-    def bounds_are_done(self):
+    def bounds_are_done(self) -> bool:
         return not any(
             (p in self.pending_points or p in self._stack) for p in self._bounds_points
         )
@@ -443,7 +452,7 @@ def interpolated_on_grid(self, n=None):
         xs, ys = self._unscale(np.vstack([xs, ys]).T).T
         return xs, ys, zs
 
-    def _data_in_bounds(self):
+    def _data_in_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.data:
             points = np.array(list(self.data.keys()))
             values = np.array(list(self.data.values()), dtype=float)
@@ -452,7 +461,7 @@ def _data_in_bounds(self):
             return points[inds], values[inds].reshape(-1, self.vdim)
         return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
 
-    def _data_interp(self):
+    def _data_interp(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.pending_points:
             points = list(self.pending_points)
             if self.bounds_are_done:
@@ -465,7 +474,7 @@ def _data_interp(self):
             return points, values
         return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
 
-    def _data_combined(self):
+    def _data_combined(self) -> Tuple[np.ndarray, np.ndarray]:
         points, values = self._data_in_bounds()
         if not self.pending_points:
             return points, values
@@ -483,7 +492,7 @@ def ip(self):
         )
         return self.interpolator(scaled=True)
 
-    def interpolator(self, *, scaled=False):
+    def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
         """A `scipy.interpolate.LinearNDInterpolator` instance
         containing the learner's data.
 
@@ -514,7 +523,7 @@ def interpolator(self, *, scaled=False):
             points, values = self._data_in_bounds()
             return interpolate.LinearNDInterpolator(points, values)
 
-    def _interpolator_combined(self):
+    def _interpolator_combined(self) -> LinearNDInterpolator:
         """A `scipy.interpolate.LinearNDInterpolator` instance
         containing the learner's data *and* interpolated data of
         the `pending_points`."""
@@ -524,12 +533,14 @@ def _interpolator_combined(self):
             self._ip_combined = interpolate.LinearNDInterpolator(points, values)
         return self._ip_combined
 
-    def inside_bounds(self, xy):
+    def inside_bounds(self, xy: Tuple[float, float],) -> Union[bool, np.bool_]:
         x, y = xy
         (xmin, xmax), (ymin, ymax) = self.bounds
         return xmin <= x <= xmax and ymin <= y <= ymax
 
-    def tell(self, point, value):
+    def tell(
+        self, point: Tuple[float, float], value: Union[float, Iterable[float]],
+    ) -> None:
         point = tuple(point)
         self.data[point] = value
         if not self.inside_bounds(point):
@@ -538,7 +549,7 @@ def tell(self, point, value):
         self._ip = None
         self._stack.pop(point, None)
 
-    def tell_pending(self, point):
+    def tell_pending(self, point: Tuple[float, float],) -> None:
         point = tuple(point)
         if not self.inside_bounds(point):
             return
@@ -546,7 +557,9 @@ def tell_pending(self, point):
         self._ip_combined = None
         self._stack.pop(point, None)
 
-    def _fill_stack(self, stack_till=1):
+    def _fill_stack(
+        self, stack_till: int = 1
+    ) -> Tuple[List[Tuple[float, float]], List[float]]:
         if len(self.data) + len(self.pending_points) < self.ndim + 1:
             raise ValueError("too few points...")
 
@@ -585,7 +598,7 @@ def _fill_stack(self, stack_till=1):
 
         return points_new, losses_new
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[np.ndarray, np.ndarray]:
         # Even if tell_pending is False we add the point such that _fill_stack
         # will return new points, later we remove these points if needed.
         points = list(self._stack.keys())
@@ -616,14 +629,14 @@ def ask(self, n, tell_pending=True):
         return points[:n], loss_improvements[:n]
 
     @cache_latest
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> float:
         if not self.bounds_are_done:
             return np.inf
         ip = self.interpolator(scaled=True) if real else self._interpolator_combined()
         losses = self.loss_per_triangle(ip)
         return losses.max()
 
-    def remove_unfinished(self):
+    def remove_unfinished(self) -> None:
         self.pending_points = set()
         for p in self._bounds_points:
             if p not in self.data:
@@ -697,10 +710,10 @@ def plot(self, n=None, tri_alpha=0):
 
         return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
 
-    def _get_data(self):
+    def _get_data(self) -> OrderedDict:
         return self.data
 
-    def _set_data(self, data):
+    def _set_data(self, data: OrderedDict) -> None:
         self.data = data
         # Remove points from stack if they already exist
         for point in copy(self._stack):
diff --git a/adaptive/learner/learnerND.py b/adaptive/learner/learnerND.py
index 39ff6fc0b..61621388d 100644
--- a/adaptive/learner/learnerND.py
+++ b/adaptive/learner/learnerND.py
@@ -3,14 +3,18 @@
 import random
 from collections import OrderedDict
 from collections.abc import Iterable
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import scipy.spatial
 from scipy import interpolate
+from scipy.spatial.qhull import ConvexHull
 from sortedcontainers import SortedKeyList
 
 from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
 from adaptive.learner.triangulation import (
+    Point,
+    Simplex,
     Triangulation,
     circumsphere,
     fast_det,
@@ -21,13 +25,13 @@
 from adaptive.utils import cache_latest, restore
 
 
-def to_list(inp):
+def to_list(inp: float) -> List[float]:
     if isinstance(inp, Iterable):
         return list(inp)
     return [inp]
 
 
-def volume(simplex, ys=None):
+def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
     # Notice the parameter ys is there so you can use this volume method as
     # as loss function
     matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float)
@@ -38,14 +42,14 @@ def volume(simplex, ys=None):
     return vol
 
 
-def orientation(simplex):
+def orientation(simplex: np.ndarray):
     matrix = np.subtract(simplex[:-1], simplex[-1])
     # See https://www.jstor.org/stable/2315353
     sign, _logdet = np.linalg.slogdet(matrix)
     return sign
 
 
-def uniform_loss(simplex, values, value_scale):
+def uniform_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> float:
     """
     Uniform loss.
 
@@ -65,7 +69,7 @@ def uniform_loss(simplex, values, value_scale):
     return volume(simplex)
 
 
-def std_loss(simplex, values, value_scale):
+def std_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> np.ndarray:
     """
     Computes the loss of the simplex based on the standard deviation.
 
@@ -91,7 +95,7 @@ def std_loss(simplex, values, value_scale):
     return r.flat * np.power(vol, 1.0 / dim) + vol
 
 
-def default_loss(simplex, values, value_scale):
+def default_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> float:
     """
     Computes the average of the volumes of the simplex.
 
@@ -116,7 +120,13 @@ def default_loss(simplex, values, value_scale):
 
 
 @uses_nth_neighbors(1)
-def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values):
+def triangle_loss(
+    simplex: np.ndarray,
+    values: np.ndarray,
+    value_scale: float,
+    neighbors: Union[List[Union[None, np.ndarray]], List[None], List[np.ndarray]],
+    neighbor_values: Union[List[Union[None, float]], List[None], List[float]],
+) -> Union[int, float]:
     """
     Computes the average of the volumes of the simplex combined with each
     neighbouring point.
@@ -153,7 +163,7 @@ def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values):
     )
 
 
-def curvature_loss_function(exploration=0.05):
+def curvature_loss_function(exploration: float = 0.05) -> Callable:
     # XXX: add doc-string!
     @uses_nth_neighbors(1)
     def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
@@ -190,7 +200,9 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
     return curvature_loss
 
 
-def choose_point_in_simplex(simplex, transform=None):
+def choose_point_in_simplex(
+    simplex: np.ndarray, transform: Optional[np.ndarray] = None,
+) -> np.ndarray:
     """Choose a new point in inside a simplex.
 
     Pick the center of the simplex if the shape is nice (that is, the
@@ -231,7 +243,7 @@ def choose_point_in_simplex(simplex, transform=None):
     return point
 
 
-def _simplex_evaluation_priority(key):
+def _simplex_evaluation_priority(key: Any) -> Any:
     # We round the loss to 8 digits such that losses
     # are equal up to numerical precision will be considered
     # to be equal. This is needed because we want the learner
@@ -291,7 +303,12 @@ class LearnerND(BaseLearner):
     children based on volume.
     """
 
-    def __init__(self, func, bounds, loss_per_simplex=None):
+    def __init__(
+        self,
+        func: Callable,
+        bounds: Union[Tuple[Tuple[float, float], ...], ConvexHull],
+        loss_per_simplex: Optional[Callable] = None,
+    ) -> None:
         self._vdim = None
         self.loss_per_simplex = loss_per_simplex or default_loss
 
@@ -324,12 +341,14 @@ def __init__(self, func, bounds, loss_per_simplex=None):
 
         self.function = func
         self._tri = None
-        self._losses = dict()
+        self._losses: Dict[Simplex, float] = dict()
 
-        self._pending_to_simplex = dict()  # vertex → simplex
+        self._pending_to_simplex: Dict[Point, Simplex] = dict()  # vertex → simplex
 
         # triangulation of the pending points inside a specific simplex
-        self._subtriangulations = dict()  # simplex → triangulation
+        self._subtriangulations: Dict[
+            Simplex, Triangulation
+        ] = dict()  # simplex → triangulation
 
         # scale to unit hypercube
         # for the input
@@ -359,7 +378,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
         self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         """Number of evaluated points."""
         return len(self.data)
 
@@ -390,7 +409,7 @@ def _ip(self):
         return interpolate.LinearNDInterpolator(self.points, self.values)
 
     @property
-    def tri(self):
+    def tri(self) -> Optional[Triangulation]:
         """An `adaptive.learner.triangulation.Triangulation` instance
         with all the points of the learner."""
         if self._tri is not None:
@@ -413,11 +432,11 @@ def values(self):
         return np.array(list(self.data.values()), dtype=float)
 
     @property
-    def points(self):
+    def points(self) -> np.ndarray:
         """Get the points from `data` as a numpy array."""
         return np.array(list(self.data.keys()), dtype=float)
 
-    def tell(self, point, value):
+    def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> None:
         point = tuple(point)
 
         if point in self.data:
@@ -441,11 +460,11 @@ def tell(self, point, value):
             to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)
             self._update_losses(to_delete, to_add)
 
-    def _simplex_exists(self, simplex):
+    def _simplex_exists(self, simplex: Simplex) -> bool:
         simplex = tuple(sorted(simplex))
         return simplex in self.tri.simplices
 
-    def inside_bounds(self, point):
+    def inside_bounds(self, point: Tuple[float, ...],) -> Union[bool, np.bool_]:
         """Check whether a point is inside the bounds."""
         if hasattr(self, "_interior"):
             return self._interior.find_simplex(point, tol=1e-8) >= 0
@@ -455,7 +474,7 @@ def inside_bounds(self, point):
                 (mn - eps) <= p <= (mx + eps) for p, (mn, mx) in zip(point, self._bbox)
             )
 
-    def tell_pending(self, point, *, simplex=None):
+    def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
         point = tuple(point)
         if not self.inside_bounds(point):
             return
@@ -482,7 +501,9 @@ def tell_pending(self, point, *, simplex=None):
                 continue
             self._update_subsimplex_losses(simpl, to_add)
 
-    def _try_adding_pending_point_to_simplex(self, point, simplex):
+    def _try_adding_pending_point_to_simplex(
+        self, point: Point, simplex: Simplex,
+    ) -> Any:
         # try to insert it
         if not self.tri.point_in_simplex(point, simplex):
             return None, None
@@ -494,7 +515,9 @@ def _try_adding_pending_point_to_simplex(self, point, simplex):
         self._pending_to_simplex[point] = simplex
         return self._subtriangulations[simplex].add_point(point)
 
-    def _update_subsimplex_losses(self, simplex, new_subsimplices):
+    def _update_subsimplex_losses(
+        self, simplex: Simplex, new_subsimplices: Set[Simplex]
+    ) -> None:
         loss = self._losses[simplex]
 
         loss_density = loss / self.tri.volume(simplex)
@@ -503,11 +526,11 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
             subloss = subtriangulation.volume(subsimplex) * loss_density
             self._simplex_queue.add((subloss, simplex, subsimplex))
 
-    def _ask_and_tell_pending(self, n=1):
+    def _ask_and_tell_pending(self, n: int = 1) -> Any:
         xs, losses = zip(*(self._ask() for _ in range(n)))
         return list(xs), list(losses)
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Any:
         """Chose points for learners."""
         if not tell_pending:
             with restore(self):
@@ -515,7 +538,7 @@ def ask(self, n, tell_pending=True):
         else:
             return self._ask_and_tell_pending(n)
 
-    def _ask_bound_point(self):
+    def _ask_bound_point(self,) -> Tuple[Point, float]:
         # get the next bound point that is still available
         new_point = next(
             p
@@ -525,7 +548,7 @@ def _ask_bound_point(self):
         self.tell_pending(new_point)
         return new_point, np.inf
 
-    def _ask_point_without_known_simplices(self):
+    def _ask_point_without_known_simplices(self,) -> Tuple[Point, float]:
         assert not self._bounds_available
         # pick a random point inside the bounds
         # XXX: change this into picking a point based on volume loss
@@ -540,7 +563,7 @@ def _ask_point_without_known_simplices(self):
         self.tell_pending(p)
         return p, np.inf
 
-    def _pop_highest_existing_simplex(self):
+    def _pop_highest_existing_simplex(self) -> Any:
         # find the simplex with the highest loss, we do need to check that the
         # simplex hasn't been deleted yet
         while len(self._simplex_queue):
@@ -566,7 +589,7 @@ def _pop_highest_existing_simplex(self):
             "  be a simplex available if LearnerND.tri() is not None."
         )
 
-    def _ask_best_point(self):
+    def _ask_best_point(self,) -> Tuple[Point, float]:
         assert self.tri is not None
 
         loss, simplex, subsimplex = self._pop_highest_existing_simplex()
@@ -587,13 +610,13 @@ def _ask_best_point(self):
         return point_new, loss
 
     @property
-    def _bounds_available(self):
+    def _bounds_available(self) -> bool:
         return any(
             (p not in self.pending_points and p not in self.data)
             for p in self._bounds_points
         )
 
-    def _ask(self):
+    def _ask(self,) -> Tuple[Point, float]:
         if self._bounds_available:
             return self._ask_bound_point()  # O(1)
 
@@ -605,7 +628,7 @@ def _ask(self):
 
         return self._ask_best_point()  # O(log N)
 
-    def _compute_loss(self, simplex):
+    def _compute_loss(self, simplex: Simplex) -> float:
         # get the loss
         vertices = self.tri.get_vertices(simplex)
         values = [self.data[tuple(v)] for v in vertices]
@@ -644,7 +667,7 @@ def _compute_loss(self, simplex):
             )
         )
 
-    def _update_losses(self, to_delete: set, to_add: set):
+    def _update_losses(self, to_delete: Set[Simplex], to_add: Set[Simplex]) -> None:
         # XXX: add the points outside the triangulation to this as well
         pending_points_unbound = set()
 
@@ -690,7 +713,7 @@ def _update_losses(self, to_delete: set, to_add: set):
                     simplex, self._subtriangulations[simplex].simplices
                 )
 
-    def _recompute_all_losses(self):
+    def _recompute_all_losses(self) -> None:
         """Recompute all losses and pending losses."""
         # amortized O(N) complexity
         if self.tri is None:
@@ -714,11 +737,11 @@ def _recompute_all_losses(self):
             )
 
     @property
-    def _scale(self):
+    def _scale(self) -> float:
         # get the output scale
         return self._max_value - self._min_value
 
-    def _update_range(self, new_output):
+    def _update_range(self, new_output: Union[List[int], float, np.ndarray]) -> bool:
         if self._min_value is None or self._max_value is None:
             # this is the first point, nothing to do, just set the range
             self._min_value = np.min(new_output)
@@ -754,12 +777,12 @@ def _update_range(self, new_output):
         return False
 
     @cache_latest
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> float:
         # XXX: compute pending loss if real == False
         losses = self._losses if self.tri is not None else dict()
         return max(losses.values()) if losses else float("inf")
 
-    def remove_unfinished(self):
+    def remove_unfinished(self) -> None:
         # XXX: implement this method
         self.pending_points = set()
         self._subtriangulations = dict()
@@ -769,7 +792,7 @@ def remove_unfinished(self):
     # Plotting related stuff #
     ##########################
 
-    def plot(self, n=None, tri_alpha=0):
+    def plot(self, n: Optional[int] = None, tri_alpha: float = 0):
         """Plot the function we want to learn, only works in 2D.
 
         Parameters
@@ -830,7 +853,7 @@ def plot(self, n=None, tri_alpha=0):
 
         return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
 
-    def plot_slice(self, cut_mapping, n=None):
+    def plot_slice(self, cut_mapping: Dict[int, float], n: Optional[int] = None):
         """Plot a 1D or 2D interpolated slice of a N-dimensional function.
 
         Parameters
@@ -900,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None):
         else:
             raise ValueError("Only 1 or 2-dimensional plots can be generated.")
 
-    def plot_3D(self, with_triangulation=False):
+    def plot_3D(self, with_triangulation: bool = False):
         """Plot the learner's data in 3D using plotly.
 
         Does *not* work with the
@@ -982,14 +1005,14 @@ def plot_3D(self, with_triangulation=False):
 
         return plotly.offline.iplot(fig)
 
-    def _get_data(self):
+    def _get_data(self) -> OrderedDict:
         return self.data
 
-    def _set_data(self, data):
+    def _set_data(self, data: OrderedDict) -> None:
         if data:
             self.tell_many(*zip(*data.items()))
 
-    def _get_iso(self, level=0.0, which="surface"):
+    def _get_iso(self, level: float = 0.0, which: str = "surface"):
         if which == "surface":
             if self.ndim != 3 or self.vdim != 1:
                 raise Exception(
@@ -1060,7 +1083,9 @@ def _get_vertex_index(a, b):
 
         return vertices, faces_or_lines
 
-    def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
+    def plot_isoline(
+        self, level: float = 0.0, n: Optional[int] = None, tri_alpha: float = 0
+    ):
         """Plot the isoline at a specific level, only works in 2D.
 
         Parameters
@@ -1100,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
         contour = contour.opts(style=contour_opts)
         return plot * contour
 
-    def plot_isosurface(self, level=0.0, hull_opacity=0.2):
+    def plot_isosurface(self, level: float = 0.0, hull_opacity: float = 0.2):
         """Plots a linearly interpolated isosurface.
 
         This is the 3D analog of an isoline. Does *not* work with the
@@ -1138,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2):
         hull_mesh = self._get_hull_mesh(opacity=hull_opacity)
         return plotly.offline.iplot([isosurface, hull_mesh])
 
-    def _get_hull_mesh(self, opacity=0.2):
+    def _get_hull_mesh(self, opacity: float = 0.2):
         plotly = ensure_plotly()
         hull = scipy.spatial.ConvexHull(self._bounds_points)
 
diff --git a/adaptive/learner/sequence_learner.py b/adaptive/learner/sequence_learner.py
index c7398dfa4..b0807ad37 100644
--- a/adaptive/learner/sequence_learner.py
+++ b/adaptive/learner/sequence_learner.py
@@ -1,5 +1,7 @@
 from copy import copy
+from typing import Any, Callable, Iterable, List, Tuple, Union
 
+import numpy as np
 from sortedcontainers import SortedDict, SortedSet
 
 from adaptive.learner.base_learner import BaseLearner
@@ -15,17 +17,19 @@ class _IgnoreFirstArgument:
     pickable.
     """
 
-    def __init__(self, function):
+    def __init__(self, function: Callable) -> None:
         self.function = function
 
-    def __call__(self, index_point, *args, **kwargs):
+    def __call__(
+        self, index_point: Tuple[int, Union[float, np.ndarray]], *args, **kwargs
+    ) -> float:
         index, point = index_point
         return self.function(point, *args, **kwargs)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Callable:
         return self.function
 
-    def __setstate__(self, function):
+    def __setstate__(self, function: Callable) -> None:
         self.__init__(function)
 
 
@@ -56,7 +60,7 @@ class SequenceLearner(BaseLearner):
     the added benefit of having results in the local kernel already.
     """
 
-    def __init__(self, function, sequence):
+    def __init__(self, function: Callable, sequence: Iterable) -> None:
         self._original_function = function
         self.function = _IgnoreFirstArgument(function)
         self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
@@ -65,7 +69,7 @@ def __init__(self, function, sequence):
         self.data = SortedDict()
         self.pending_points = set()
 
-    def ask(self, n, tell_pending=True):
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[Any, List[float]]:
         indices = []
         points = []
         loss_improvements = []
@@ -83,17 +87,17 @@ def ask(self, n, tell_pending=True):
 
         return points, loss_improvements
 
-    def _get_data(self):
+    def _get_data(self) -> SortedDict:
         return self.data
 
-    def _set_data(self, data):
+    def _set_data(self, data: SortedDict) -> None:
         if data:
             indices, values = zip(*data.items())
             # the points aren't used by tell, so we can safely pass None
             points = [(i, None) for i in indices]
             self.tell_many(points, values)
 
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> float:
         if not (self._to_do_indices or self.pending_points):
             return 0
         else:
@@ -105,13 +109,13 @@ def remove_unfinished(self):
             self._to_do_indices.add(i)
         self.pending_points = set()
 
-    def tell(self, point, value):
+    def tell(self, point: Tuple[int, Any], value: Any,) -> None:
         index, point = point
         self.data[index] = value
         self.pending_points.discard(index)
         self._to_do_indices.discard(index)
 
-    def tell_pending(self, point):
+    def tell_pending(self, point: Any) -> None:
         index, point = point
         self.pending_points.add(index)
         self._to_do_indices.discard(index)
@@ -126,5 +130,5 @@ def result(self):
         return list(self.data.values())
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         return len(self.data)
diff --git a/adaptive/learner/skopt_learner.py b/adaptive/learner/skopt_learner.py
index 88911d096..7dd698b69 100644
--- a/adaptive/learner/skopt_learner.py
+++ b/adaptive/learner/skopt_learner.py
@@ -1,4 +1,5 @@
 import collections
+from typing import Callable, List, Tuple, Union
 
 import numpy as np
 from skopt import Optimizer
@@ -23,13 +24,13 @@ class SKOptLearner(Optimizer, BaseLearner):
         Arguments to pass to ``skopt.Optimizer``.
     """
 
-    def __init__(self, function, **kwargs):
+    def __init__(self, function: Callable, **kwargs) -> None:
         self.function = function
         self.pending_points = set()
         self.data = collections.OrderedDict()
         super().__init__(**kwargs)
 
-    def tell(self, x, y, fit=True):
+    def tell(self, x: Union[float, List[float]], y: float, fit: bool = True) -> None:
         if isinstance(x, collections.abc.Iterable):
             self.pending_points.discard(tuple(x))
             self.data[tuple(x)] = y
@@ -48,7 +49,7 @@ def remove_unfinished(self):
         pass
 
     @cache_latest
-    def loss(self, real=True):
+    def loss(self, real: bool = True) -> float:
         if not self.models:
             return np.inf
         else:
@@ -58,7 +59,12 @@ def loss(self, real=True):
             # estimator of loss, but it is the cheapest.
             return 1 - model.score(self.Xi, self.yi)
 
-    def ask(self, n, tell_pending=True):
+    def ask(
+        self, n: int, tell_pending: bool = True
+    ) -> Union[
+        Tuple[List[float], List[float]],
+        Tuple[List[List[float]], List[float]],  # XXX: this indicates a bug!
+    ]:
         if not tell_pending:
             raise NotImplementedError(
                 "Asking points is an irreversible "
@@ -72,7 +78,7 @@ def ask(self, n, tell_pending=True):
             return [p[0] for p in points], [self.loss() / n] * n
 
     @property
-    def npoints(self):
+    def npoints(self) -> int:
         """Number of evaluated points."""
         return len(self.Xi)
 
diff --git a/adaptive/learner/triangulation.py b/adaptive/learner/triangulation.py
index 0cc0bdeb9..a59c265e4 100644
--- a/adaptive/learner/triangulation.py
+++ b/adaptive/learner/triangulation.py
@@ -1,14 +1,21 @@
+import collections.abc
 import math
 from collections import Counter
-from collections.abc import Iterable, Sized
 from itertools import chain, combinations
 from math import factorial
+from typing import Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 import numpy as np
 import scipy.spatial
 
+SimplexPoints = Union[
+    List[Tuple[float, ...]], np.ndarray
+]  # XXX: check if this is correct
+Simplex = Tuple[int, ...]
+Point = Union[Tuple[float, ...], np.ndarray]  # XXX: check if this is correct
 
-def fast_norm(v):
+
+def fast_norm(v: Union[Tuple[float, ...], np.ndarray]) -> float:
     # notice this method can be even more optimised
     if len(v) == 2:
         return math.sqrt(v[0] * v[0] + v[1] * v[1])
@@ -17,7 +24,9 @@ def fast_norm(v):
     return math.sqrt(np.dot(v, v))
 
 
-def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
+def fast_2d_point_in_simplex(
+    point: Point, simplex: SimplexPoints, eps: float = 1e-8
+) -> Union[bool, np.bool_]:
     (p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
     px, py = point
 
@@ -31,7 +40,7 @@ def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
     return (t >= -eps) and (s + t <= 1 + eps)
 
 
-def point_in_simplex(point, simplex, eps=1e-8):
+def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> bool:
     if len(point) == 2:
         return fast_2d_point_in_simplex(point, simplex, eps)
 
@@ -42,7 +51,7 @@ def point_in_simplex(point, simplex, eps=1e-8):
     return all(alpha > -eps) and sum(alpha) < 1 + eps
 
 
-def fast_2d_circumcircle(points):
+def fast_2d_circumcircle(points: Iterable[Point]) -> Tuple[Tuple[float, float], float]:
     """Compute the center and radius of the circumscribed circle of a triangle
 
     Parameters
@@ -78,7 +87,9 @@ def fast_2d_circumcircle(points):
     return (x + points[0][0], y + points[0][1]), radius
 
 
-def fast_3d_circumcircle(points):
+def fast_3d_circumcircle(
+    points: Iterable[Point],
+) -> Tuple[Tuple[float, float, float], float]:
     """Compute the center and radius of the circumscribed shpere of a simplex.
 
     Parameters
@@ -118,7 +129,7 @@ def fast_3d_circumcircle(points):
     return center, radius
 
 
-def fast_det(matrix):
+def fast_det(matrix: np.ndarray) -> float:
     matrix = np.asarray(matrix, dtype=float)
     if matrix.shape == (2, 2):
         return matrix[0][0] * matrix[1][1] - matrix[1][0] * matrix[0][1]
@@ -129,7 +140,7 @@ def fast_det(matrix):
         return np.linalg.det(matrix)
 
 
-def circumsphere(pts):
+def circumsphere(pts: np.ndarray) -> Tuple[Tuple[float, ...], float]:
     dim = len(pts) - 1
     if dim == 2:
         return fast_2d_circumcircle(pts)
@@ -155,7 +166,7 @@ def circumsphere(pts):
     return tuple(center), radius
 
 
-def orientation(face, origin):
+def orientation(face: np.ndarray, origin: np.ndarray) -> int:
     """Compute the orientation of the face with respect to a point, origin.
 
     Parameters
@@ -181,11 +192,13 @@ def orientation(face, origin):
     return sign
 
 
-def is_iterable_and_sized(obj):
-    return isinstance(obj, Iterable) and isinstance(obj, Sized)
+def is_iterable_and_sized(obj: Any) -> bool:
+    return isinstance(obj, collections.abc.Iterable) and isinstance(
+        obj, collections.abc.Sized
+    )
 
 
-def simplex_volume_in_embedding(vertices) -> float:
+def simplex_volume_in_embedding(vertices: Iterable[Point]) -> float:
     """Calculate the volume of a simplex in a higher dimensional embedding.
     That is: dim > len(vertices) - 1. For example if you would like to know the
     surface area of a triangle in a 3d space.
@@ -266,7 +279,7 @@ class Triangulation:
         or more simplices in the
     """
 
-    def __init__(self, coords):
+    def __init__(self, coords: Iterable[Point]) -> None:
         if not is_iterable_and_sized(coords):
             raise TypeError("Please provide a 2-dimensional list of points")
         coords = list(coords)
@@ -294,10 +307,10 @@ def __init__(self, coords):
                 "(the points are linearly dependent)"
             )
 
-        self.vertices = list(coords)
-        self.simplices = set()
+        self.vertices: List[Point] = list(coords)
+        self.simplices: Set[Simplex] = set()
         # initialise empty set for each vertex
-        self.vertex_to_simplices = [set() for _ in coords]
+        self.vertex_to_simplices: List[Set[Simplex]] = [set() for _ in coords]
 
         # find a Delaunay triangulation to start with, then we will throw it
         # away and continue with our own algorithm
@@ -305,27 +318,29 @@ def __init__(self, coords):
         for simplex in initial_tri.simplices:
             self.add_simplex(simplex)
 
-    def delete_simplex(self, simplex):
+    def delete_simplex(self, simplex: Simplex) -> None:
         simplex = tuple(sorted(simplex))
         self.simplices.remove(simplex)
         for vertex in simplex:
             self.vertex_to_simplices[vertex].remove(simplex)
 
-    def add_simplex(self, simplex):
+    def add_simplex(self, simplex: Simplex) -> None:
         simplex = tuple(sorted(simplex))
         self.simplices.add(simplex)
         for vertex in simplex:
             self.vertex_to_simplices[vertex].add(simplex)
 
-    def get_vertices(self, indices):
+    def get_vertices(self, indices: Sequence[int]) -> List[Optional[Point]]:
         return [self.get_vertex(i) for i in indices]
 
-    def get_vertex(self, index):
+    def get_vertex(self, index: Optional[int]) -> Optional[Point]:
         if index is None:
             return None
         return self.vertices[index]
 
-    def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
+    def get_reduced_simplex(
+        self, point: Point, simplex: Simplex, eps: float = 1e-8
+    ) -> list:
         """Check whether vertex lies within a simplex.
 
         Returns
@@ -350,11 +365,13 @@ def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
 
         return [simplex[i] for i in result]
 
-    def point_in_simplex(self, point, simplex, eps=1e-8):
+    def point_in_simplex(
+        self, point: Point, simplex: Simplex, eps: float = 1e-8
+    ) -> bool:
         vertices = self.get_vertices(simplex)
         return point_in_simplex(point, vertices, eps)
 
-    def locate_point(self, point):
+    def locate_point(self, point: Point) -> Simplex:
         """Find to which simplex the point belongs.
 
         Return indices of the simplex containing the point.
@@ -366,10 +383,15 @@ def locate_point(self, point):
         return ()
 
     @property
-    def dim(self):
+    def dim(self) -> int:
         return len(self.vertices[0])
 
-    def faces(self, dim=None, simplices=None, vertices=None):
+    def faces(
+        self,
+        dim: Optional[int] = None,
+        simplices: Optional[Iterable[Simplex]] = None,
+        vertices: Optional[Iterable[int]] = None,
+    ) -> Iterator[Tuple[int, ...]]:
         """Iterator over faces of a simplex or vertex sequence."""
         if dim is None:
             dim = self.dim
@@ -390,11 +412,11 @@ def faces(self, dim=None, simplices=None, vertices=None):
         else:
             return faces
 
-    def containing(self, face):
+    def containing(self, face: Tuple[int, ...]) -> Set[Simplex]:
         """Simplices containing a face."""
         return set.intersection(*(self.vertex_to_simplices[i] for i in face))
 
-    def _extend_hull(self, new_vertex, eps=1e-8):
+    def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> Set[Simplex]:
         # count multiplicities in order to get all hull faces
         multiplicities = Counter(face for face in self.faces())
         hull_faces = [face for face, count in multiplicities.items() if count == 1]
@@ -434,7 +456,9 @@ def _extend_hull(self, new_vertex, eps=1e-8):
 
         return new_simplices
 
-    def circumscribed_circle(self, simplex, transform):
+    def circumscribed_circle(
+        self, simplex: Simplex, transform: np.ndarray
+    ) -> Tuple[Tuple[float, ...], float]:
         """Compute the center and radius of the circumscribed circle of a simplex.
 
         Parameters
@@ -450,7 +474,9 @@ def circumscribed_circle(self, simplex, transform):
         pts = np.dot(self.get_vertices(simplex), transform)
         return circumsphere(pts)
 
-    def point_in_cicumcircle(self, pt_index, simplex, transform):
+    def point_in_cicumcircle(
+        self, pt_index: int, simplex: Simplex, transform: np.ndarray
+    ) -> bool:
         # return self.fast_point_in_circumcircle(pt_index, simplex, transform)
         eps = 1e-8
 
@@ -460,10 +486,15 @@ def point_in_cicumcircle(self, pt_index, simplex, transform):
         return np.linalg.norm(center - pt) < (radius * (1 + eps))
 
     @property
-    def default_transform(self):
+    def default_transform(self) -> np.ndarray:
         return np.eye(self.dim)
 
-    def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
+    def bowyer_watson(
+        self,
+        pt_index: int,
+        containing_simplex: Optional[Simplex] = None,
+        transform: Optional[np.ndarray] = None,
+    ) -> Tuple[Set[Simplex], Set[Simplex]]:
         """Modified Bowyer-Watson point adding algorithm.
 
         Create a hole in the triangulation around the new point,
@@ -523,10 +554,10 @@ def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
         new_triangles = self.vertex_to_simplices[pt_index]
         return bad_triangles - new_triangles, new_triangles - bad_triangles
 
-    def _simplex_is_almost_flat(self, simplex):
+    def _simplex_is_almost_flat(self, simplex: Simplex) -> bool:
         return self._relative_volume(simplex) < 1e-8
 
-    def _relative_volume(self, simplex):
+    def _relative_volume(self, simplex: Simplex) -> float:
         """Compute the volume of a simplex divided by the average (Manhattan)
         distance of its vertices. The advantage of this is that the relative
         volume is only dependent on the shape of the simplex and not on the
@@ -537,20 +568,25 @@ def _relative_volume(self, simplex):
         average_edge_length = np.mean(np.abs(vectors))
         return self.volume(simplex) / (average_edge_length ** self.dim)
 
-    def add_point(self, point, simplex=None, transform=None):
+    def add_point(
+        self,
+        point: Point,
+        simplex: Optional[Simplex] = None,
+        transform: Optional[np.ndarray] = None,
+    ) -> Any:
         """Add a new vertex and create simplices as appropriate.
 
         Parameters
         ----------
         point : float vector
             Coordinates of the point to be added.
-        transform : N*N matrix of floats
-            Multiplication matrix to apply to the point (and neighbouring
-            simplices) when running the Bowyer Watson method.
         simplex : tuple of ints, optional
             Simplex containing the point. Empty tuple indicates points outside
             the hull. If not provided, the algorithm costs O(N), so this should
             be used whenever possible.
+        transform : N*N matrix of floats
+            Multiplication matrix to apply to the point (and neighbouring
+            simplices) when running the Bowyer Watson method.
         """
         point = tuple(point)
         if simplex is None:
@@ -586,16 +622,16 @@ def add_point(self, point, simplex=None, transform=None):
             self.vertices.append(point)
             return self.bowyer_watson(pt_index, actual_simplex, transform)
 
-    def volume(self, simplex):
+    def volume(self, simplex: Simplex) -> float:
         prefactor = np.math.factorial(self.dim)
         vertices = np.array(self.get_vertices(simplex))
         vectors = vertices[1:] - vertices[0]
         return float(abs(fast_det(vectors)) / prefactor)
 
-    def volumes(self):
+    def volumes(self) -> List[float]:
         return [self.volume(sim) for sim in self.simplices]
 
-    def reference_invariant(self):
+    def reference_invariant(self) -> bool:
         """vertex_to_simplices and simplices are compatible."""
         for vertex in range(len(self.vertices)):
             if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]):
@@ -609,26 +645,28 @@ def vertex_invariant(self, vertex):
         """Simplices originating from a vertex don't overlap."""
         raise NotImplementedError
 
-    def get_neighbors_from_vertices(self, simplex):
+    def get_neighbors_from_vertices(self, simplex: Simplex) -> Set[Simplex]:
         return set.union(*[self.vertex_to_simplices[p] for p in simplex])
 
-    def get_face_sharing_neighbors(self, neighbors, simplex):
+    def get_face_sharing_neighbors(
+        self, neighbors: Set[Simplex], simplex: Simplex
+    ) -> Set[Simplex]:
         """Keep only the simplices sharing a whole face with simplex."""
         return {
             simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim
         }  # they share a face
 
-    def get_simplices_attached_to_points(self, indices):
+    def get_simplices_attached_to_points(self, indices: Simplex) -> Set[Simplex]:
         # Get all simplices that share at least a point with the simplex
         neighbors = self.get_neighbors_from_vertices(indices)
         return self.get_face_sharing_neighbors(neighbors, indices)
 
-    def get_opposing_vertices(self, simplex):
+    def get_opposing_vertices(self, simplex: Simplex) -> Tuple[int, ...]:
         if simplex not in self.simplices:
             raise ValueError("Provided simplex is not part of the triangulation")
         neighbors = self.get_simplices_attached_to_points(simplex)
 
-        def find_opposing_vertex(vertex):
+        def find_opposing_vertex(vertex: int):
             # find the simplex:
             simp = next((x for x in neighbors if vertex not in x), None)
             if simp is None:
@@ -641,7 +679,7 @@ def find_opposing_vertex(vertex):
         return result
 
     @property
-    def hull(self):
+    def hull(self) -> Set[int]:
         """Compute hull from triangulation.
 
         Parameters
diff --git a/adaptive/notebook_integration.py b/adaptive/notebook_integration.py
index 017f951b2..b54afd365 100644
--- a/adaptive/notebook_integration.py
+++ b/adaptive/notebook_integration.py
@@ -76,7 +76,7 @@ def ensure_plotly():
         raise RuntimeError("plotly is not installed; plotting is disabled.")
 
 
-def in_ipynb():
+def in_ipynb() -> bool:
     try:
         # If we are running in IPython, then `get_ipython()` is always a global
         return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
diff --git a/adaptive/runner.py b/adaptive/runner.py
index eb707e725..043f56127 100644
--- a/adaptive/runner.py
+++ b/adaptive/runner.py
@@ -7,14 +7,23 @@
 import time
 import traceback
 import warnings
+from _asyncio import Future, Task
 from contextlib import suppress
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
+from adaptive.learner.base_learner import BaseLearner
 from adaptive.notebook_integration import in_ipynb, live_info, live_plot
 
+_ThirdPartyClient = []
+_ThirdPartyExecutor = []
+
 try:
     import ipyparallel
+    from ipyparallel.client.asyncresult import AsyncResult
 
     with_ipyparallel = True
+    _ThirdPartyClient.append(ipyparallel.Client)
+    _ThirdPartyExecutor.append(ipyparallel.client.view.ViewExecutor)
 except ModuleNotFoundError:
     with_ipyparallel = False
 
@@ -22,6 +31,8 @@
     import distributed
 
     with_distributed = True
+    _ThirdPartyClient.append(distributed.client.Client)
+    _ThirdPartyExecutor.append(distributed.cfexecutor.ClientExecutor)
 except ModuleNotFoundError:
     with_distributed = False
 
@@ -29,6 +40,7 @@
     import mpi4py.futures
 
     with_mpi4py = True
+    _ThirdPartyExecutor.append(mpi4py.futures.MPIPoolExecutor)
 except ModuleNotFoundError:
     with_mpi4py = False
 
@@ -37,9 +49,84 @@
 
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+ThirdPartyClient = Union[tuple(_ThirdPartyClient)]
+ThirdPartyExecutor = Union[tuple(_ThirdPartyExecutor)]
+
+
+# -- Internal executor-related, things
+
+
+class SequentialExecutor(concurrent.Executor):
+    """A trivial executor that runs functions synchronously.
+
+    This executor is mainly for testing.
+    """
+
+    def submit(self, fn: Callable, *args, **kwargs) -> Future:
+        fut = concurrent.Future()
+        try:
+            fut.set_result(fn(*args, **kwargs))
+        except Exception as e:
+            fut.set_exception(e)
+        return fut
+
+    def map(self, fn, *iterable, timeout=None, chunksize=1):
+        return map(fn, iterable)
+
+    def shutdown(self, wait=True):
+        pass
+
+
+def _ensure_executor(
+    executor: Optional[Union[ThirdPartyClient, concurrent.Executor]]
+) -> concurrent.Executor:
+    if executor is None:
+        executor = concurrent.ProcessPoolExecutor()
+
+    if isinstance(executor, concurrent.Executor):
+        return executor
+    elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
+        return executor.executor()
+    elif with_distributed and isinstance(executor, distributed.Client):
+        return executor.get_executor()
+    else:
+        raise TypeError(
+            "Only a concurrent.futures.Executor, distributed.Client,"
+            " or ipyparallel.Client can be used."
+        )
+
+
+def _get_ncores(
+    ex: Union[
+        ThirdPartyExecutor,
+        concurrent.ProcessPoolExecutor,
+        concurrent.ThreadPoolExecutor,
+        SequentialExecutor,
+    ]
+) -> int:
+    """Return the maximum  number of cores that an executor can use."""
+    if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
+        return len(ex.view)
+    elif isinstance(
+        ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
+    ):
+        return ex._max_workers  # not public API!
+    elif isinstance(ex, SequentialExecutor):
+        return 1
+    elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
+        return sum(n for n in ex._client.ncores().values())
+    elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
+        ex.bootup()  # wait until all workers are up and running
+        return ex._pool.size  # not public API!
+    else:
+        raise TypeError(f"Cannot get number of cores for {ex.__class__}")
+
+
+# -- Runner definitions
+
 
 class BaseRunner(metaclass=abc.ABCMeta):
-    r"""Base class for runners that use `concurrent.futures.Executors`.
+    r"""Base class for runners that use `concurrent.futures.Executor`\'s.
 
     Parameters
     ----------
@@ -94,16 +181,21 @@ class BaseRunner(metaclass=abc.ABCMeta):
 
     def __init__(
         self,
-        learner,
-        goal,
+        learner: BaseLearner,
+        goal: Callable,
         *,
-        executor=None,
-        ntasks=None,
-        log=False,
-        shutdown_executor=False,
-        retries=0,
-        raise_if_retries_exceeded=True,
-    ):
+        executor: Union[
+            ThirdPartyExecutor,
+            concurrent.ProcessPoolExecutor,
+            concurrent.ThreadPoolExecutor,
+            SequentialExecutor,
+        ] = None,
+        ntasks: int = None,
+        log: bool = False,
+        shutdown_executor: bool = False,
+        retries: int = 0,
+        raise_if_retries_exceeded: bool = True,
+    ) -> None:
 
         self.executor = _ensure_executor(executor)
         self.goal = goal
@@ -117,7 +209,7 @@ def __init__(
         self.shutdown_executor = shutdown_executor or (executor is None)
 
         self.learner = learner
-        self.log = [] if log else None
+        self.log: Optional[list] = [] if log else None
 
         # Timing
         self.start_time = time.time()
@@ -130,7 +222,7 @@ def __init__(
         self.to_retry = {}
         self.tracebacks = {}
 
-    def _get_max_tasks(self):
+    def _get_max_tasks(self) -> int:
         return self._max_tasks or _get_ncores(self.executor)
 
     def _do_raise(self, e, x):
@@ -142,10 +234,10 @@ def _do_raise(self, e, x):
         ) from e
 
     @property
-    def do_log(self):
+    def do_log(self) -> bool:
         return self.log is not None
 
-    def _ask(self, n):
+    def _ask(self, n: int) -> Any:
         points = [
             p for p in self.to_retry.keys() if p not in self.pending_points.values()
         ][:n]
@@ -179,7 +271,14 @@ def overhead(self):
         t_total = self.elapsed_time()
         return (1 - t_function / t_total) * 100
 
-    def _process_futures(self, done_futs):
+    def _process_futures(
+        self,
+        done_futs: Union[
+            Set[Future],
+            Set[AsyncResult],  # XXX: AsyncResult might not be imported
+            Set[Task],
+        ],
+    ) -> None:
         for fut in done_futs:
             x = self.pending_points.pop(fut)
             try:
@@ -200,7 +299,13 @@ def _process_futures(self, done_futs):
                     self.log.append(("tell", x, y))
                 self.learner.tell(x, y)
 
-    def _get_futures(self):
+    def _get_futures(
+        self,
+    ) -> Union[
+        List[Task],
+        List[Future],
+        List[AsyncResult],  # XXX: AsyncResult might not be imported
+    ]:
         # Launch tasks to replace the ones that completed
         # on the last iteration, making sure to fill workers
         # that have started since the last iteration.
@@ -221,7 +326,7 @@ def _get_futures(self):
         futures = list(self.pending_points.keys())
         return futures
 
-    def _remove_unfinished(self):
+    def _remove_unfinished(self) -> List[Future]:
         # remove points with 'None' values from the learner
         self.learner.remove_unfinished()
         # cancel any outstanding tasks
@@ -230,7 +335,7 @@ def _remove_unfinished(self):
             fut.cancel()
         return remaining
 
-    def _cleanup(self):
+    def _cleanup(self) -> None:
         if self.shutdown_executor:
             # XXX: temporary set wait=True for Python 3.7
             # see https://github.com/python-adaptive/adaptive/issues/156
@@ -318,16 +423,21 @@ class BlockingRunner(BaseRunner):
 
     def __init__(
         self,
-        learner,
-        goal,
+        learner: BaseLearner,
+        goal: Callable,
         *,
-        executor=None,
-        ntasks=None,
+        executor: Union[
+            ThirdPartyExecutor,
+            concurrent.ProcessPoolExecutor,
+            concurrent.ThreadPoolExecutor,
+            SequentialExecutor,
+        ] = None,
+        ntasks: Optional[int] = None,
         log=False,
         shutdown_executor=False,
         retries=0,
         raise_if_retries_exceeded=True,
-    ):
+    ) -> None:
         if inspect.iscoroutinefunction(learner.function):
             raise ValueError(
                 "Coroutine functions can only be used " "with 'AsyncRunner'."
@@ -344,10 +454,10 @@ def __init__(
         )
         self._run()
 
-    def _submit(self, x):
+    def _submit(self, x: Union[Tuple[float, ...], float, int]) -> Future:
         return self.executor.submit(self.learner.function, x)
 
-    def _run(self):
+    def _run(self) -> None:
         first_completed = concurrent.FIRST_COMPLETED
 
         if self._get_max_tasks() < 1:
@@ -445,17 +555,22 @@ class AsyncRunner(BaseRunner):
 
     def __init__(
         self,
-        learner,
-        goal=None,
+        learner: BaseLearner,
+        goal: Optional[Callable] = None,
         *,
-        executor=None,
-        ntasks=None,
-        log=False,
-        shutdown_executor=False,
+        executor: Union[
+            ThirdPartyExecutor,
+            concurrent.ProcessPoolExecutor,
+            concurrent.ThreadPoolExecutor,
+            SequentialExecutor,
+        ] = None,
+        ntasks: Optional[int] = None,
+        log: bool = False,
+        shutdown_executor: bool = False,
         ioloop=None,
-        retries=0,
-        raise_if_retries_exceeded=True,
-    ):
+        retries: int = 0,
+        raise_if_retries_exceeded: bool = True,
+    ) -> None:
 
         if goal is None:
 
@@ -508,7 +623,9 @@ def goal(_):
                 "'adaptive.notebook_extension()'"
             )
 
-    def _submit(self, x):
+    def _submit(
+        self, x: Union[Tuple[int, int], int, Tuple[float, float], float]
+    ) -> Union[Task, Future]:
         ioloop = self.ioloop
         if inspect.iscoroutinefunction(self.learner.function):
             return ioloop.create_task(self.learner.function(x))
@@ -573,7 +690,7 @@ def live_info(self, *, update_interval=0.1):
         """
         return live_info(self, update_interval=update_interval)
 
-    async def _run(self):
+    async def _run(self) -> None:
         first_completed = asyncio.FIRST_COMPLETED
 
         if self._get_max_tasks() < 1:
@@ -592,7 +709,7 @@ async def _run(self):
                 await asyncio.wait(remaining)
             self._cleanup()
 
-    def elapsed_time(self):
+    def elapsed_time(self) -> float:
         """Return the total time elapsed since the runner
         was started."""
         if self.task.done():
@@ -605,7 +722,7 @@ def elapsed_time(self):
             end_time = time.time()
         return end_time - self.start_time
 
-    def start_periodic_saving(self, save_kwargs, interval):
+    def start_periodic_saving(self, save_kwargs: Dict[str, Any], interval: int):
         """Periodically save the learner's data.
 
         Parameters
@@ -637,7 +754,7 @@ async def _saver(save_kwargs=save_kwargs, interval=interval):
 Runner = AsyncRunner
 
 
-def simple(learner, goal):
+def simple(learner: BaseLearner, goal: Callable) -> None:
     """Run the learner until the goal is reached.
 
     Requests a single point from the learner, evaluates
@@ -663,7 +780,7 @@ def simple(learner, goal):
             learner.tell(x, y)
 
 
-def replay_log(learner, log):
+def replay_log(learner: BaseLearner, log) -> None:
     """Apply a sequence of method calls to a learner.
 
     This is useful for debugging runners.
@@ -682,7 +799,7 @@ def replay_log(learner, log):
 # --- Useful runner goals
 
 
-def stop_after(*, seconds=0, minutes=0, hours=0):
+def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
     """Stop a runner after a specified time.
 
     For example, to specify a runner that should stop after
@@ -714,63 +831,3 @@ def stop_after(*, seconds=0, minutes=0, hours=0):
     """
     stop_time = time.time() + seconds + 60 * minutes + 3600 * hours
     return lambda _: time.time() > stop_time
-
-
-# -- Internal executor-related, things
-
-
-class SequentialExecutor(concurrent.Executor):
-    """A trivial executor that runs functions synchronously.
-
-    This executor is mainly for testing.
-    """
-
-    def submit(self, fn, *args, **kwargs):
-        fut = concurrent.Future()
-        try:
-            fut.set_result(fn(*args, **kwargs))
-        except Exception as e:
-            fut.set_exception(e)
-        return fut
-
-    def map(self, fn, *iterable, timeout=None, chunksize=1):
-        return map(fn, iterable)
-
-    def shutdown(self, wait=True):
-        pass
-
-
-def _ensure_executor(executor):
-    if executor is None:
-        executor = concurrent.ProcessPoolExecutor()
-
-    if isinstance(executor, concurrent.Executor):
-        return executor
-    elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
-        return executor.executor()
-    elif with_distributed and isinstance(executor, distributed.Client):
-        return executor.get_executor()
-    else:
-        raise TypeError(
-            "Only a concurrent.futures.Executor, distributed.Client,"
-            " or ipyparallel.Client can be used."
-        )
-
-
-def _get_ncores(ex):
-    """Return the maximum  number of cores that an executor can use."""
-    if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
-        return len(ex.view)
-    elif isinstance(
-        ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
-    ):
-        return ex._max_workers  # not public API!
-    elif isinstance(ex, SequentialExecutor):
-        return 1
-    elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
-        return sum(n for n in ex._client.ncores().values())
-    elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
-        ex.bootup()  # wait until all workers are up and running
-        return ex._pool.size  # not public API!
-    else:
-        raise TypeError(f"Cannot get number of cores for {ex.__class__}")
diff --git a/adaptive/tests/algorithm_4.py b/adaptive/tests/algorithm_4.py
index 102590e3c..1eaa6b5f5 100644
--- a/adaptive/tests/algorithm_4.py
+++ b/adaptive/tests/algorithm_4.py
@@ -2,7 +2,8 @@
 # Copyright 2017 Christoph Groth
 
 from collections import defaultdict
-from fractions import Fraction as Frac
+from fractions import Fraction
+from typing import Callable, List, Tuple, Union
 
 import numpy as np
 from numpy.testing import assert_allclose
@@ -11,7 +12,7 @@
 eps = np.spacing(1)
 
 
-def legendre(n):
+def legendre(n: int) -> List[List[Fraction]]:
     """Return the first n Legendre polynomials.
 
     The polynomials have *standard* normalization, i.e.
@@ -19,12 +20,12 @@ def legendre(n):
 
     The return value is a list of list of fraction.Fraction instances.
     """
-    result = [[Frac(1)], [Frac(0), Frac(1)]]
+    result = [[Fraction(1)], [Fraction(0), Fraction(1)]]
     if n <= 2:
         return result[:n]
     for i in range(2, n):
         # Use Bonnet's recursion formula.
-        new = (i + 1) * [Frac(0)]
+        new = (i + 1) * [Fraction(0)]
         new[1:] = (r * (2 * i - 1) for r in result[-1])
         new[:-2] = (n - r * (i - 1) for n, r in zip(new[:-2], result[-2]))
         new[:] = (n / i for n in new)
@@ -32,7 +33,7 @@ def legendre(n):
     return result
 
 
-def newton(n):
+def newton(n: int) -> np.ndarray:
     """Compute the monomial coefficients of the Newton polynomial over the
     nodes of the n-point Clenshaw-Curtis quadrature rule.
     """
@@ -89,7 +90,7 @@ def newton(n):
     return cf
 
 
-def scalar_product(a, b):
+def scalar_product(a: List[Fraction], b: List[Fraction]) -> Fraction:
     """Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
 
     The args must be sequences of polynomial coefficients.  This
@@ -110,7 +111,7 @@ def scalar_product(a, b):
     return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
 
 
-def calc_bdef(ns):
+def calc_bdef(ns: Tuple[int, int, int, int]) -> List[np.ndarray]:
     """Calculate the decompositions of Newton polynomials (over the nodes
     of the n-point Clenshaw-Curtis quadrature rule) in terms of
     Legandre polynomials.
@@ -123,7 +124,7 @@ def calc_bdef(ns):
     result = []
     for n in ns:
         poly = []
-        a = list(map(Frac, newton(n)))
+        a = list(map(Fraction, newton(n)))
         for b in legs[: n + 1]:
             igral = scalar_product(a, b)
 
@@ -145,7 +146,7 @@ def calc_bdef(ns):
 b_def = calc_bdef(n)
 
 
-def calc_V(xi, n):
+def calc_V(xi: np.ndarray, n: int) -> np.ndarray:
     V = [np.ones(xi.shape), xi.copy()]
     for i in range(2, n):
         V.append((2 * i - 1) / i * xi * V[-1] - (i - 1) / i * V[-2])
@@ -183,7 +184,7 @@ def calc_V(xi, n):
 gamma = np.concatenate([[0, 0], np.sqrt(k[2:] ** 2 / (4 * k[2:] ** 2 - 1))])
 
 
-def _downdate(c, nans, depth):
+def _downdate(c: np.ndarray, nans: List[int], depth: int) -> None:
     # This is algorithm 5 from the thesis of Pedro Gonnet.
     b = b_def[depth].copy()
     m = n[depth] - 1
@@ -200,7 +201,7 @@ def _downdate(c, nans, depth):
         m -= 1
 
 
-def _zero_nans(fx):
+def _zero_nans(fx: np.ndarray) -> List[int]:
     nans = []
     for i in range(len(fx)):
         if not np.isfinite(fx[i]):
@@ -209,7 +210,7 @@ def _zero_nans(fx):
     return nans
 
 
-def _calc_coeffs(fx, depth):
+def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
     """Caution: this function modifies fx."""
     nans = _zero_nans(fx)
     c_new = V_inv[depth] @ fx
@@ -220,7 +221,7 @@ def _calc_coeffs(fx, depth):
 
 
 class DivergentIntegralError(ValueError):
-    def __init__(self, msg, igral, err, nr_points):
+    def __init__(self, msg: str, igral: float, err: None, nr_points: int) -> None:
         self.igral = igral
         self.err = err
         self.nr_points = nr_points
@@ -230,19 +231,23 @@ def __init__(self, msg, igral, err, nr_points):
 class _Interval:
     __slots__ = ["a", "b", "c", "fx", "igral", "err", "depth", "rdepth", "ndiv", "c00"]
 
-    def __init__(self, a, b, depth, rdepth):
+    def __init__(
+        self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int
+    ) -> None:
         self.a = a
         self.b = b
         self.depth = depth
         self.rdepth = rdepth
 
-    def points(self):
+    def points(self) -> np.ndarray:
         a = self.a
         b = self.b
         return (a + b) / 2 + (b - a) * xi[self.depth] / 2
 
     @classmethod
-    def make_first(cls, f, a, b, depth=2):
+    def make_first(
+        cls, f: Callable, a: int, b: int, depth: int = 2
+    ) -> Tuple["_Interval", int]:
         ival = _Interval(a, b, depth, 1)
         fx = f(ival.points())
         ival.c = _calc_coeffs(fx, depth)
@@ -251,7 +256,7 @@ def make_first(cls, f, a, b, depth=2):
         ival.ndiv = 0
         return ival, n[depth]
 
-    def calc_igral_and_err(self, c_old):
+    def calc_igral_and_err(self, c_old: np.ndarray) -> float:
         self.c = c_new = _calc_coeffs(self.fx, self.depth)
         c_diff = np.zeros(max(len(c_old), len(c_new)))
         c_diff[: len(c_old)] = c_old
@@ -262,7 +267,9 @@ def calc_igral_and_err(self, c_old):
         self.err = w * c_diff
         return c_diff
 
-    def split(self, f):
+    def split(
+        self, f: Callable
+    ) -> Union[Tuple[Tuple[float, float, float], int], Tuple[List["_Interval"], int]]:
         m = (self.a + self.b) / 2
         f_center = self.fx[(len(self.fx) - 1) // 2]
 
@@ -287,7 +294,7 @@ def split(self, f):
 
         return ivals, nr_points
 
-    def refine(self, f):
+    def refine(self, f: Callable) -> Tuple[np.ndarray, bool, int]:
         """Increase degree of interval."""
         self.depth = depth = self.depth + 1
         points = self.points()
@@ -299,7 +306,9 @@ def refine(self, f):
         return points, split, n[depth] - n[depth - 1]
 
 
-def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
+def algorithm_4(
+    f: Callable, a: int, b: int, tol: float, N_loops: int = int(1e9)
+) -> Tuple[float, float, int, List["_Interval"]]:
     """ALGORITHM_4 evaluates an integral using adaptive quadrature. The
     algorithm uses Clenshaw-Curtis quadrature rules of increasing
     degree in each interval and bisects the interval if either the
@@ -403,29 +412,31 @@ def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
     return igral, err, nr_points, ivals
 
 
-################ Tests ################
+# ############### Tests ################
 
 
-def f0(x):
+def f0(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     return x * np.sin(1 / x) * np.sqrt(abs(1 - x))
 
 
-def f7(x):
+def f7(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     return x ** -0.5
 
 
-def f24(x):
+def f24(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     return np.floor(np.exp(x))
 
 
-def f21(x):
+def f21(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     y = 0
     for i in range(1, 4):
         y += 1 / np.cosh(20 ** i * (x - 2 * i / 10))
     return y
 
 
-def f63(x, alpha, beta):
+def f63(
+    x: Union[float, np.ndarray], alpha: float, beta: float
+) -> Union[float, np.ndarray]:
     return abs(x - beta) ** alpha
 
 
@@ -433,7 +444,7 @@ def F63(x, alpha, beta):
     return (x - beta) * abs(x - beta) ** alpha / (alpha + 1)
 
 
-def fdiv(x):
+def fdiv(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     return abs(x - 0.987654321) ** -1.1
 
 
@@ -461,7 +472,9 @@ def test_scalar_product(n=33):
     selection = [0, 5, 7, n - 1]
     for i in selection:
         for j in selection:
-            assert scalar_product(legs[i], legs[j]) == ((i == j) and Frac(2, 2 * i + 1))
+            assert scalar_product(legs[i], legs[j]) == (
+                (i == j) and Fraction(2, 2 * i + 1)
+            )
 
 
 def simple_newton(n):
diff --git a/adaptive/utils.py b/adaptive/utils.py
index 035086205..f58b72d81 100644
--- a/adaptive/utils.py
+++ b/adaptive/utils.py
@@ -5,6 +5,7 @@
 import pickle
 from contextlib import contextmanager
 from itertools import product
+from typing import Any, Callable, Iterator
 
 from atomicwrites import AtomicWriter
 
@@ -16,7 +17,7 @@ def named_product(**items):
 
 
 @contextmanager
-def restore(*learners):
+def restore(*learners) -> Iterator[None]:
     states = [learner.__getstate__() for learner in learners]
     try:
         yield
@@ -25,7 +26,7 @@ def restore(*learners):
             learner.__setstate__(state)
 
 
-def cache_latest(f):
+def cache_latest(f: Callable) -> Callable:
     """Cache the latest return value of the function and add it
     as 'self._cache[f.__name__]'."""
 
@@ -40,7 +41,7 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def save(fname, data, compress=True):
+def save(fname: str, data: Any, compress: bool = True) -> None:
     fname = os.path.expanduser(fname)
     dirname = os.path.dirname(fname)
     if dirname:
@@ -54,14 +55,14 @@ def save(fname, data, compress=True):
         f.write(blob)
 
 
-def load(fname, compress=True):
+def load(fname: str, compress: bool = True):
     fname = os.path.expanduser(fname)
     _open = gzip.open if compress else open
     with _open(fname, "rb") as f:
         return pickle.load(f)
 
 
-def copy_docstring_from(other):
+def copy_docstring_from(other: Callable) -> Callable:
     def decorator(method):
         return functools.wraps(other)(method)