Skip to content

Commit

Permalink
Remove jaxtyping import.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed May 10, 2024
1 parent 321f345 commit da4b31f
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions harmonic/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
from jaxtyping import Array
import tensorflow_probability as tfp
import distrax

Expand Down Expand Up @@ -211,7 +210,7 @@ class RQSpline(nn.Module):
num_bins: int
spline_range: Sequence[float] = (-10.0, 10.0)
multimodal_base: bool = False
base_centers: Sequence[Array] = None
base_centers: Sequence[jnp.ndarray]] = None

def setup(self):
conditioner = []
Expand Down

0 comments on commit da4b31f

Please sign in to comment.