In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [3]:
import numpy as np
from h5py import File

from magnet_pinn.preprocessing.preprocessing import GridPreprocessing
from magnet_pinn.data.transforms import Crop, GridPhaseShift, Compose, DefaultTransform, PhaseShift
from magnet_pinn.data.grid import MagnetGridIterator
from magnet_pinn.data.dataitem import DataItem

In [4]:
d = DataItem(
        simulation="",
        input=np.random.rand(3, 20, 20, 20).astype(np.float32),
        field=np.random.rand(2, 2, 3, 20, 20, 20, 8).astype(np.float32),
        subject=np.random.choice([0, 1], size=(20, 20, 20)).astype(np.bool_),
        phase=np.random.rand(8).astype(np.float32),
        mask=np.random.choice([0, 1], size=8).astype(np.bool_),
        coils=np.random.choice([0, 1], size=(20, 20, 20, 8)).astype(np.float32),
        dtype="float32",
        truncation_coefficients=np.ones(3, dtype=np.float32)
    )

In [5]:
aug = PhaseShift(num_coils=8, sampling_method="uniform")

In [6]:
result = aug(d)

In [7]:
im = d.field[:, 1, :, :, :, :, :]
im.shape

(2, 3, 20, 20, 20, 8)

In [8]:
real = d.field[:, 0, :, :, :, :, :]
real.shape

(2, 3, 20, 20, 20, 8)

In [9]:
fields_complex = real + 1j * im

In [10]:
coefs = np.exp(1j * result.phase) * result.mask

In [11]:
coefs.shape

(8,)

In [12]:
res = fields_complex * coefs
res

array([[[[[[ 0.00000000e+00+0.00000000e+00j,
            -8.31457496e-01+3.86249460e-02j,
            -0.00000000e+00+0.00000000e+00j, ...,
             0.00000000e+00+0.00000000e+00j,
             4.50602882e-02+1.08338463e+00j,
            -0.00000000e+00+0.00000000e+00j],
           [ 0.00000000e+00+0.00000000e+00j,
            -7.49886692e-01+5.30582190e-01j,
            -0.00000000e+00+0.00000000e+00j, ...,
             0.00000000e+00+0.00000000e+00j,
            -4.26522195e-01+8.40288997e-01j,
            -0.00000000e+00+0.00000000e+00j],
           [ 0.00000000e+00+0.00000000e+00j,
            -3.97057235e-01+7.01594651e-01j,
            -0.00000000e+00+0.00000000e+00j, ...,
             0.00000000e+00+0.00000000e+00j,
            -2.10207775e-02+7.56656468e-01j,
            -0.00000000e+00+0.00000000e+00j],
           ...,
           [ 0.00000000e+00+0.00000000e+00j,
            -9.57924575e-02+2.83134039e-02j,
            -0.00000000e+00+0.00000000e+00j, ...,
             0.0

In [29]:
reshaped = np.stack([res.real, res.imag], axis=1)
reshaped.shape

(2, 2, 3, 20, 20, 20, 8)

In [30]:
resulted = np.sum(reshaped, axis=-1)
resulted.shape

(2, 2, 3, 20, 20, 20)

In [31]:
resulted

array([[[[[[ 1.35896683e+00,  2.38721085e+00,  1.58780551e+00, ...,
             7.40999222e-01,  2.47242451e-01,  1.65906906e+00],
           [ 1.58591211e-01,  6.29201531e-01,  2.53426313e-01, ...,
             3.15751076e+00,  2.40315080e-01,  6.78167582e-01],
           [ 1.03546250e+00,  5.09796262e-01,  1.10754967e-02, ...,
             1.04615843e+00,  1.19260335e+00,  2.58565664e+00],
           ...,
           [ 1.84185326e+00,  1.17613864e+00,  1.75592351e+00, ...,
             9.66786981e-01,  1.01889133e+00,  2.44708538e+00],
           [ 1.88603723e+00,  2.16998458e+00,  7.73967147e-01, ...,
             3.00950599e+00,  2.28399706e+00,  1.89626181e+00],
           [-7.06765652e-02,  1.46073794e+00,  2.19531918e+00, ...,
             9.07467484e-01,  1.94543147e+00,  1.27283871e+00]],

          [[ 1.47690368e+00,  8.07802439e-01,  1.53692794e+00, ...,
             7.60617256e-01,  1.56508064e+00,  2.23195243e+00],
           [ 1.68592751e+00,  1.86014462e+00,  1.98342419e

In [32]:
result.field

array([[[[[[ 1.35896671e+00,  2.38721085e+00,  1.58780563e+00, ...,
             7.40999162e-01,  2.47242451e-01,  1.65906894e+00],
           [ 1.58591181e-01,  6.29201353e-01,  2.53426373e-01, ...,
             3.15751076e+00,  2.40315080e-01,  6.78167820e-01],
           [ 1.03546238e+00,  5.09796262e-01,  1.10756159e-02, ...,
             1.04615843e+00,  1.19260335e+00,  2.58565664e+00],
           ...,
           [ 1.84185326e+00,  1.17613864e+00,  1.75592327e+00, ...,
             9.66787100e-01,  1.01889157e+00,  2.44708562e+00],
           [ 1.88603723e+00,  2.16998458e+00,  7.73967326e-01, ...,
             3.00950575e+00,  2.28399682e+00,  1.89626169e+00],
           [-7.06766248e-02,  1.46073794e+00,  2.19531918e+00, ...,
             9.07467484e-01,  1.94543135e+00,  1.27283883e+00]],

          [[ 1.47690380e+00,  8.07802200e-01,  1.53692794e+00, ...,
             7.60617316e-01,  1.56508088e+00,  2.23195267e+00],
           [ 1.68592739e+00,  1.86014473e+00,  1.98342443e

In [35]:
np.isclose(resulted, result.field).all()

False

In [41]:
resulting

array([[[[[[ 5.02911210e-01,  8.80008996e-01,  6.29223108e-01, ...,
             2.56481916e-01,  1.47069097e-01,  7.92935967e-01],
           [ 1.14489269e+00,  2.85320222e-01,  8.22915137e-01, ...,
             6.74429774e-01,  9.84184146e-02, -3.22465032e-01],
           [ 3.99248123e-01,  6.85926080e-01,  8.95432651e-01, ...,
            -6.31891489e-02,  2.07548261e-01,  1.99757054e-01],
           ...,
           [ 5.04164755e-01,  7.62968183e-01,  2.30382591e-01, ...,
            -8.38835537e-02,  6.15491390e-01,  8.97669375e-01],
           [-4.25584584e-01,  5.42343929e-02,  1.10900328e-01, ...,
            -3.30843031e-03,  1.11077356e+00,  5.56092441e-01],
           [-1.26004398e-01,  1.00203466e+00,  4.02090698e-02, ...,
            -4.05320346e-01,  4.71729696e-01,  3.27121973e-01]],

          [[ 1.44857585e-01,  4.89838868e-01,  4.45851594e-01, ...,
            -3.05227339e-01,  5.78607202e-01,  1.42177030e-01],
           [-2.43228823e-01,  7.78244510e-02,  1.44798249e

In [13]:
coils_resulted = d.coils * np.exp(1j * result.phase) * result.mask

In [14]:
coils_resulted.shape

(20, 20, 20, 8)

In [17]:
summed_coils = np.sum(np.stack([coils_resulted.real, coils_resulted.imag], axis=0), axis=-1)

In [20]:
summed_coils

array([[[[-1.0193837 , -0.05306143, -0.01626603, ..., -0.01626603,
           0.        , -0.9825883 ],
         [-0.25378048,  0.        , -0.25378048, ..., -0.25378048,
           0.71254176, -0.7656032 ],
         [-0.96632224, -1.7481915 ,  0.6962757 , ..., -0.01626603,
           0.6962757 , -0.25378048],
         ...,
         [-0.78186923, -1.7481915 , -1.7481915 , ..., -0.06932747,
          -0.05306143, -0.9825883 ],
         [-0.06932747, -0.7656032 , -0.9825883 , ..., -1.0356498 ,
          -0.9825883 , -0.9825883 ],
         [ 0.71254176, -0.9825883 , -0.01626603, ..., -0.06932747,
          -0.06932747, -0.2700465 ]],

        [[ 0.6962757 , -0.2700465 , -1.0193837 , ...,  0.        ,
          -0.2700465 , -1.7481915 ],
         [-0.78186923, -0.7656032 , -0.2700465 , ...,  0.71254176,
           0.71254176, -0.25378048],
         [-1.0356498 ,  0.71254176, -0.06932747, ..., -0.05306143,
          -0.9825883 ,  0.71254176],
         ...,
         [-0.06932747, -1.7319255 

In [21]:
result.coils

array([[[[-1.0193837 , -0.05306137, -0.01626603, ..., -0.01626603,
           0.        , -0.9825883 ],
         [-0.25378042,  0.        , -0.25378042, ..., -0.25378042,
           0.7125418 , -0.7656032 ],
         [-0.96632224, -1.7481915 ,  0.6962758 , ..., -0.01626603,
           0.6962758 , -0.25378042],
         ...,
         [-0.78186923, -1.7481915 , -1.7481915 , ..., -0.0693274 ,
          -0.05306137, -0.9825883 ],
         [-0.0693274 , -0.7656032 , -0.9825883 , ..., -1.0356497 ,
          -0.9825883 , -0.9825883 ],
         [ 0.7125418 , -0.9825883 , -0.01626603, ..., -0.0693274 ,
          -0.0693274 , -0.27004647]],

        [[ 0.6962758 , -0.27004647, -1.0193837 , ...,  0.        ,
          -0.27004647, -1.7481915 ],
         [-0.78186923, -0.7656032 , -0.27004647, ...,  0.7125418 ,
           0.7125418 , -0.25378042],
         [-1.0356497 ,  0.7125418 , -0.0693274 , ..., -0.05306137,
          -0.9825883 ,  0.7125418 ],
         ...,
         [-0.0693274 , -1.7319255 

In [20]:
result.input.shape

(3, 20, 20, 20)

In [11]:
np.random.binomial(10, 0.5)

4

In [4]:
np.random.seed(42)

In [5]:
gen = np.random.default_rng()

In [6]:
gen.random(size=(2, 2))

array([[0.17238913, 0.17864619],
       [0.11500042, 0.28666673]])

In [7]:
gen.choice([0, 1], size=(2, 2))

array([[1, 1],
       [0, 0]])

In [8]:
one_more_gen = np.random.default_rng()

In [9]:
one_more_gen.random(size=(2, 2))

array([[0.44955109, 0.87970409],
       [0.65075623, 0.31944685]])

In [10]:
one_more_gen.choice([0, 1], size=(2, 2))

array([[0, 1],
       [1, 0]])

In [10]:
np.random.rand(2, 2)

array([[0.4236548 , 0.64589411],
       [0.43758721, 0.891773  ]])

In [29]:
aug = PhaseShift(
    num_coils=8
)

In [30]:
iter = MagnetGridIterator(
    "../data/processed/grid_voxel_size_4_data_type_float32",
    transforms=aug,
    num_samples=1
)

In [66]:
import h5py
from typing import List, Tuple
import numpy.typing as npt

In [69]:
def read_field(f: h5py.File, field_key: str) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
            field_val = f[field_key][:]
            if field_val.dtype.names is None:
                return field_val.real, field_val.imag
            return field_val["re"], field_val["im"]
        
with h5py.File("../data/processed/grid_voxel_size_4_data_type_float32/simulations/children_0_tubes_0_id_28641.h5") as f:
    re_efield, im_efield = read_field(f, "efield")
    re_hfield, im_hfield = read_field(f, "hfield")
        
res = np.stack([np.stack([re_efield, im_efield], axis=0), np.stack([re_hfield, im_hfield], axis=0)], axis=0)

In [70]:
res.shape

(2, 2, 3, 121, 111, 126, 8)

In [31]:
for i in iter:
    break

In [11]:
i.keys()

dict_keys(['simulation', 'input', 'field', 'subject', 'positions', 'phase', 'mask', 'coils', 'dtype', 'truncation_coefficients'])

In [12]:
i["simulation"]

'children_0_tubes_7_id_25481'

In [14]:
i["input"].shape

(3, 121, 111, 126)

In [15]:
i["field"].shape

(2, 2, 3, 121, 111, 126)

In [17]:
i["subject"].shape

(121, 111, 126)

In [18]:
i["positions"]

[]

In [20]:
i["phase"].shape

(8,)

In [21]:
i["mask"].shape

(8,)

In [23]:
i["coils"].shape

(2, 121, 111, 126)

In [24]:
i["dtype"]

'float32'

In [25]:
i["truncation_coefficients"]

array([1., 1., 1.], dtype=float32)

In [35]:
with File("../data/processed/grid_voxel_size_4_data_type_float32/simulations/children_0_tubes_0_id_28641.h5") as f:
    print(f.keys())
    e = f["efield"][:]

<KeysViewHDF5 ['efield', 'hfield', 'input', 'subject']>


In [36]:
e.shape

(3, 121, 111, 126, 8)

In [43]:
phase = np.random.uniform(0, 2*np.pi, 8).astype(np.float32)
phase.shape

(8,)

In [42]:
mask = np.zeros(8, dtype=bool)
mask.shape

(8,)

In [50]:
re_phase = np.cos(phase) * mask

In [51]:
im_phase = np.sin(phase) * mask

In [52]:
coef_real = np.stack((re_phase, -im_phase), axis=0)
coef_real.shape

(2, 8)

In [53]:
coef_im = np.stack((im_phase, re_phase), axis=0)
coef_im.shape

(2, 8)

In [58]:
coeffs = np.stack((coef_real, coef_im), axis=0)
coeffs.shape

(2, 2, 8)

In [59]:
import einops

In [60]:
new_coefs = einops.repeat(coeffs, 'reimout reim coils -> hf reimout reim coils', hf=2)
new_coefs.shape

(2, 2, 2, 8)

In [61]:
e.shape

(3, 121, 111, 126, 8)

In [73]:
einops.einsum(
    res, new_coefs,
    "hf reim fieldxyz ... coils, hf reimout reim coils -> hf reim fieldxyz ...",
).shape

(2, 2, 3, 121, 111, 126)