Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_obj import MetaObj, get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
3 changes: 2 additions & 1 deletion monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class MetaObj:

"""

_meta: dict
def __init__(self):
self._meta: dict = self.get_default_meta()

@staticmethod
def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
Expand Down
31 changes: 15 additions & 16 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,33 @@ class MetaTensor(MetaObj, torch.Tensor):

@staticmethod
def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore

def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
"""
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
Else, use the default value. Similar for the affine, except this could come from
four places.
Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
"""
out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore
super().__init__()
# set meta
if meta is not None:
out.meta = meta
self.meta = meta
elif isinstance(x, MetaObj):
out.meta = x.meta
else:
out.meta = out.get_default_meta()
self.meta = x.meta
# set the affine
if affine is not None:
if "affine" in out.meta:
warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.")
out.affine = affine
elif "affine" in out.meta:
if "affine" in self.meta:
warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
self.affine = affine
elif "affine" in self.meta:
pass # nothing to do
elif isinstance(x, MetaTensor):
out.affine = x.affine
self.affine = x.affine
else:
out.affine = out.get_default_affine()
out.affine = out.affine.to(out.device)

return out
self.affine = self.get_default_affine()
self.affine = self.affine.to(self.device)

def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
Expand All @@ -113,8 +112,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
ret.affine = ret.affine.to(ret.device)
return ret

def get_default_affine(self) -> torch.Tensor:
return torch.eye(4, device=self.device)
def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=self.device, dtype=dtype)

def as_tensor(self) -> torch.Tensor:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TestMetaTensor(unittest.TestCase):
@staticmethod
def get_im(shape=None, dtype=None, device=None):
if shape is None:
shape = shape = (1, 10, 8)
shape = (1, 10, 8)
affine = torch.randint(0, 10, (4, 4))
meta = {"fname": rand_string()}
t = torch.rand(shape)
Expand Down