We get a lot of Pylance errors with TensorType, and adding casts everywhere just to silence them feels kind of ugly. If I understand correctly, this is simply the expected behaviour with torchtyping (is it?). What do we want to do here?
P.S. Should we use jaxtyping instead of torchtyping? In the README, Patrick says that it supports static type checkers.