In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

In [3]:
import numpy as np

from magnet_pinn.data.transforms import PhaseShift
from magnet_pinn.data.dataitem import DataItem
from tests.dataloading.transforms.helpers import check_constant_values_not_changed_by_phase_shift, check_complex_number_calculations_in_phase_shift

In [4]:
d = DataItem(
    simulation="children_0_tubes_0_id_0",
    input=np.random.rand(8000, 3).astype(np.float32),
    field=np.random.rand(2, 2, 8000, 3, 8).astype(np.float32),
    subject=np.random.choice([0, 1], size=(8000, 1)).astype(np.bool_),
    positions=np.random.rand(8000, 3).astype(np.float32),
    phase=np.random.rand(8).astype(np.float32),
    mask=np.random.choice([0, 1], size=8).astype(np.bool_),
    coils=np.random.choice([0, 1], size=(8000, 8)).astype(np.float32),
    dtype="float32",
    truncation_coefficients=np.ones(3, dtype=np.float32)
)

In [5]:
result = PhaseShift(num_coils=8, sampling_method="uniform")(d)

In [6]:
check_constant_values_not_changed_by_phase_shift(result, d)

In [7]:
check_complex_number_calculations_in_phase_shift(result, d)

AssertionError: 

In [10]:
coefs_re = np.cos(result.phase) * result.mask
coefs_im = np.sin(result.phase) * result.mask

In [11]:
field_re = d.field[:, 0]
field_im = d.field[:, 1]

In [12]:
field_shifted_re = field_re @ coefs_re - field_im @ coefs_im
field_shifted_im = field_re @ coefs_im + field_im @ coefs_re

In [13]:
expected_field_result = np.stack([field_shifted_re, field_shifted_im], axis=1)

In [14]:
expected_field_result

array([[[[ 1.7302717 ,  3.5270162 ,  2.316648  ],
         [ 3.5336263 ,  2.722262  ,  3.4592142 ],
         [ 2.9691706 ,  2.5594609 ,  2.4527364 ],
         ...,
         [ 2.1708608 ,  3.6709418 ,  2.619008  ],
         [ 3.2647884 ,  3.8266983 ,  3.4480264 ],
         [ 2.4031959 ,  2.1997106 ,  2.68777   ]],

        [[ 0.19592285, -0.96540934, -0.7467792 ],
         [ 0.73843205, -0.00969386, -0.14822876],
         [-1.1526308 , -0.21496224, -0.544299  ],
         ...,
         [ 0.28395557, -0.3347748 , -0.323209  ],
         [-0.3788184 , -0.62655187, -1.139818  ],
         [-0.6905152 , -1.4673979 ,  1.1556563 ]]],


       [[[ 2.9423523 ,  3.345036  ,  1.9349489 ],
         [ 2.3816807 ,  3.3742952 ,  2.4488127 ],
         [ 2.8354166 ,  2.8870258 ,  3.67661   ],
         ...,
         [ 2.5870075 ,  1.5659733 ,  3.9213467 ],
         [ 3.08634   ,  3.9047036 ,  2.6152987 ],
         [ 3.347014  ,  1.4825528 ,  2.1317024 ]],

        [[ 0.08114052,  0.03826773, -0.30205345],


In [15]:
result.field

array([[[[ 1.7302718 ,  3.5270162 ,  2.316648  ],
         [ 3.5336263 ,  2.722262  ,  3.459214  ],
         [ 2.9691706 ,  2.5594609 ,  2.4527361 ],
         ...,
         [ 2.1708608 ,  3.6709418 ,  2.619008  ],
         [ 3.2647882 ,  3.8266983 ,  3.4480267 ],
         [ 2.4031959 ,  2.1997104 ,  2.68777   ]],

        [[ 0.19592285, -0.9654094 , -0.7467792 ],
         [ 0.73843205, -0.00969386, -0.14822865],
         [-1.1526308 , -0.21496212, -0.544299  ],
         ...,
         [ 0.28395557, -0.3347749 , -0.32320905],
         [-0.3788184 , -0.62655175, -1.139818  ],
         [-0.6905152 , -1.4673977 ,  1.1556563 ]]],


       [[[ 2.9423523 ,  3.345036  ,  1.9349489 ],
         [ 2.381681  ,  3.3742952 ,  2.4488127 ],
         [ 2.8354166 ,  2.887026  ,  3.6766102 ],
         ...,
         [ 2.5870075 ,  1.5659733 ,  3.9213467 ],
         [ 3.08634   ,  3.9047039 ,  2.6152987 ],
         [ 3.3470142 ,  1.4825528 ,  2.1317024 ]],

        [[ 0.08114052,  0.03826797, -0.30205333],


In [16]:
np.equal(expected_field_result, result.field).all()

False