# Explore the 2D+t heart MRI dataset

In [None]:
%load_ext jupyter_black


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from scipy.io import loadmat
from tempfile import NamedTemporaryFile

from projectB.utils.plotting import PlotUtils

## 1. Load data

In [None]:
# Read the raw MATLAB file; scipy hands back a dict keyed by variable name.
data = loadmat(file_name="../data/raw/2dt_heart.mat")

In [None]:
# Extract the "imgs" array and reorder it: original axis 3 becomes axis 0 and
# original axis 2 becomes axis 1. Axis 0 is treated as the video index below
# (presumably the result is (n_videos, n_frames, H, W) — TODO confirm layout).
videos = np.moveaxis(data["imgs"], source=(2, 3), destination=(1, 0))
videos.shape

## 2. Visualize the videos

In [None]:
# Pick 16 distinct videos (one per cell of the 4x4 grid) and display them.
# A seeded Generator (same seed as the undersampler used later) keeps the
# selection identical across re-runs. Every later figure reuses
# `sampled_indices`, so all of those figures stay reproducible too.
sampled_indices = np.random.default_rng(42).choice(
    videos.shape[0], size=16, replace=False
)
PlotUtils.display_video_grid(
    videos[sampled_indices], grid_size=(4, 4), figsize=(24, 24)
)

## 3. Fourier-transform the images into k-space

In [None]:
from projectB.data_handling.transforms.fft import FFT2D
import torch

# Convert the image stack to a torch tensor and run the project's FFT2D
# transform over it (presumably a per-frame 2D FFT of the spatial axes).
fft = FFT2D()
videos_fft = fft(torch.tensor(videos))

# Magnitudes of the transformed data for the same 16 sampled videos;
# norm="log" is passed through to the plotting helper.
PlotUtils.display_video_grid(
    np.abs(videos_fft[sampled_indices]),
    grid_size=(4, 4),
    figsize=(24, 24),
    norm="log",
)

## 4. Randomly mask 75% of rows

In [None]:
from projectB.data_handling.transforms.undersampling import UniformUndersampler

# Row undersampler for the transformed videos. seed=42 is passed so the
# random mask is presumably reproducible (seeding semantics live in projectB).
undersampler = UniformUndersampler(factor=0.5, hw_center=2, seed=42)

# Call the transform object instead of `.forward()` directly. This matches how
# `fft(...)` is applied above and how v2.Compose applies it in section 6.
# For torch Modules, calling the instance also runs registered hooks that a
# bare `.forward()` call would skip.
videos_fft_masked = undersampler(videos_fft)

In [None]:
# Show magnitudes of the undersampled data for the same 16 videos, with
# norm="log" forwarded to the plotting helper as in section 3.
PlotUtils.display_video_grid(
    np.abs(videos_fft_masked[sampled_indices]),
    norm="log",
    figsize=(24, 24),
    grid_size=(4, 4),
)

## 5. Inverse-FFT the masked data back to image space

In [None]:
# Undo the centring shift on the last two axes, then invert the 2D FFT.
# np.fft.ifft2 already transforms axes (-2, -1) by default.
# NOTE(review): assumes FFT2D applied an fftshift with numpy's default
# normalisation — TODO confirm this inverse mirrors FFT2D exactly.
videos_masked = np.fft.ifft2(np.fft.ifftshift(videos_fft_masked, axes=(-2, -1)))

# Magnitude of the (complex) reconstructions for the same 16 videos
PlotUtils.display_video_grid(
    np.abs(videos_masked[sampled_indices]),
    grid_size=(4, 4),
    figsize=(24, 24),
)

## 6. Do it all at once with a composed transform

In [None]:
from torchvision.transforms import v2

# Chain the steps from sections 3-4 into one callable: FFT2D first, then the
# undersampler with identical settings. A fresh UniformUndersampler is built
# rather than reusing `undersampler` — TODO confirm a new instance with
# seed=42 reproduces the section-4 mask.
fft_then_mask = [
    FFT2D(),
    UniformUndersampler(factor=0.5, hw_center=2, seed=42),
]
transforms = v2.Compose(fft_then_mask)

videos_fft_masked = transforms(torch.tensor(videos))

# Magnitudes of the composed pipeline's output for the same 16 videos
PlotUtils.display_video_grid(
    np.abs(videos_fft_masked[sampled_indices]),
    norm="log",
    grid_size=(4, 4),
    figsize=(24, 24),
)