# Explore the 2D+t cardiac MRI dataset

In [None]:
# Load the jupyter_black IPython extension so code cells are auto-formatted with black.
%load_ext jupyter_black

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from scipy.io import loadmat
from tempfile import NamedTemporaryFile

sys.path.append("../")  # append repo root dir

from utils.plotting import display_video_grid

## 1. Load data

In [None]:
# Read the raw MATLAB file; loadmat returns a dict keyed by MATLAB variable name.
mat_path = "../data/raw/2dt_heart.mat"
data = loadmat(mat_path)

In [None]:
# Pull out the "imgs" array and reorder its axes: original axis 3 moves to the
# front and original axis 2 to position 1, with axes 0 and 1 trailing.
# Presumably this yields (video, frame, row, col) — TODO confirm against the .mat layout.
videos = np.moveaxis(data["imgs"], source=(2, 3), destination=(1, 0))
videos.shape

## 2. Visualize the videos

In [None]:
# Draw 16 distinct videos at random (reused by later cells) and show them on a 4x4 grid.
n_videos = videos.shape[0]
sampled_indices = np.random.choice(n_videos, 16, replace=False)
display_video_grid(
    videos[sampled_indices],
    grid_size=(4, 4),
    figsize=(24, 24),
)

## 3. Fourier-transform the images into k-space

In [None]:
# 2D FFT of every image plane (last two axes), then move the zero-frequency
# component to the centre of each frame.
image_axes = (-2, -1)
videos_fft = np.fft.fft2(videos, axes=image_axes)
videos_fft = np.fft.fftshift(videos_fft, axes=image_axes)

# Show the spectrum magnitudes of the sampled videos on a log colour scale.
display_video_grid(
    np.abs(videos_fft[sampled_indices]),
    grid_size=(4, 4),
    figsize=(24, 24),
    norm="log",
)

## 4. Randomly mask 50% of k-space rows (keeping the centre row)

In [None]:
factor = 0.5  # fraction of k-space rows to discard in every frame

# Rows are axis 2 of videos_fft. After fftshift the zero-frequency row sits at
# index num_rows // 2; it is excluded from the candidates so it is always kept.
num_rows = videos_fft.shape[2]
middle_row = num_rows // 2
row_indices = np.delete(np.arange(num_rows), middle_row)

# Never request more rows than can be dropped (the centre row is protected);
# otherwise np.random.choice(replace=False) would raise for factor close to 1.
num_zero_rows = min(int(factor * num_rows), row_indices.size)

# row_mask[i, j, r] is True when row r of frame j of video i is discarded.
# An independent random subset of rows is drawn for every (video, frame) pair,
# iterating in the same order as a nested i/j loop.
row_mask = np.zeros(videos_fft.shape[:3], dtype=bool)
for i, j in np.ndindex(videos_fft.shape[:2]):
    selected_rows = np.random.choice(row_indices, num_zero_rows, replace=False)
    row_mask[i, j, selected_rows] = True

# Full-size, read-only view of the row mask with the same shape as videos_fft.
mask = np.broadcast_to(row_mask[..., np.newaxis], videos_fft.shape)

# Zero-fill the discarded rows so the inverse FFT treats them as unsampled.
videos_fft_masked = videos_fft.copy()
videos_fft_masked[mask] = 0

# log(0) is undefined, so discarded rows are drawn at magnitude 1 for display
# only; videos_fft_masked itself stays zero-filled.
display_video_grid(
    np.where(
        mask[sampled_indices], 1.0, np.abs(videos_fft_masked[sampled_indices])
    ),
    grid_size=(4, 4),
    figsize=(24, 24),
    norm="log",
)

## 5. Reconstruct the images with an inverse FFT

In [None]:
# Undo the centring shift, then invert the 2D FFT over each image plane.
kspace_unshifted = np.fft.ifftshift(videos_fft_masked, axes=(-2, -1))
videos_masked = np.fft.ifft2(kspace_unshifted, axes=(-2, -1))

# The reconstruction is complex-valued; show its magnitude.
display_video_grid(
    np.abs(videos_masked[sampled_indices]),
    grid_size=(4, 4),
    figsize=(24, 24),
)