In this file I start by implementing a compressed sensing reconstruction approach. I want a well-established classical method 
as a baseline for my later deep learning models.

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]= '2'  # make only GPU index 2 (zero-based, i.e. the third GPU) visible to CUDA

from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt

1. Loading data

In [None]:
# The data is stored in a MATLAB struct named `csi` (csi = chemical shift imaging).
# Consistent data from a total of 5 subjects (P03 ... P07) is loaded below.


def _load_csi(mat_path):
    """Load one CombinedCSI .mat file.

    Returns the tuple ``(mat_dict, csi_struct, data)``, where ``mat_dict`` is
    the raw ``loadmat`` result, ``csi_struct`` is its ``'csi'`` entry and
    ``data`` is ``csi_struct['Data'][0, 0]``.
    """
    mat_dict = loadmat(mat_path)
    csi_struct = mat_dict['csi']
    return mat_dict, csi_struct, csi_struct['Data'][0, 0]


mat_data_1, csi_1, Data_1 = _load_csi('fn_vb_DMI_CRT_P03/CombinedCSI_full_rank.mat')
mat_data_2, csi_2, Data_2 = _load_csi('fn_vb_DMI_CRT_P04/CombinedCSI_full_rank.mat')
mat_data_3, csi_3, Data_3 = _load_csi('fn_vb_DMI_CRT_P05/CombinedCSI_full_rank.mat')
mat_data_4, csi_4, Data_4 = _load_csi('fn_vb_DMI_CRT_P06/CombinedCSI_full_rank.mat')
mat_data_5, csi_5, Data_5 = _load_csi('fn_vb_DMI_CRT_P07/CombinedCSI_full_rank.mat')

# Stack the five subjects along a new trailing "dataset" axis and cache it on disk
# (np.save appends the '.npy' extension to the given file name).
combined_data = np.stack((Data_1, Data_2, Data_3, Data_4, Data_5), axis=-1)
np.save('combined_data_full_rank', combined_data)

In [None]:
# Plot |signal| of every z-slice (one row per z) for every dataset (one column per
# dataset) at a single fixed (time, long_time) sample.
# Assume combined_data contains the stacked data for all 5 datasets
# Shape of combined_data: (x, y, z, time, long_time, dataset)

# Specify fixed indices
time_index = 0  # Fixed time index
long_time_index = 1  # Fixed long_time index

# Define the range of z indices
z_indices = range(0, 21)  # z indices 0 to 20
n_datasets = combined_data.shape[-1]  # Number of datasets (assumed last dimension)

# Determine grid size
n_rows = len(z_indices)  # One row per z index
n_cols = n_datasets  # One column per dataset

# Create a figure
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 2.5 * n_rows))
axes = np.array(axes).reshape(n_rows, n_cols)  # Ensure 2D array of axes even for a single row/column

# Loop through z indices and datasets
for i, z in enumerate(z_indices):
    for j in range(n_datasets):
        # Extract the 2D slice for the given z index, time index, and long_time index
        slice_data = combined_data[:, :, z, time_index, long_time_index, j]
        absolute_slice = np.abs(slice_data)

        # Plot the slice
        ax = axes[i, j]
        im = ax.imshow(absolute_slice, cmap='viridis', origin='lower')
        if i == 0:  # Column headers on the top row only
            ax.set_title(f"Dataset {j+1}", fontsize=10)
        if j == 0:  # Row labels in the first column only
            ax.set_ylabel(f"z Index {z}", fontsize=10)
        # Hide ticks and frame but keep the axis itself: ax.axis("off") would also
        # hide the y-label, so the per-row "z Index" labels would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Adjust layout
plt.tight_layout()
plt.show()


The amplitude of the signal decreases over time. I start by visualizing how the overall (spatially averaged) signal of a slice changes over the small time t.

In [None]:
# Plot the spatially averaged signal magnitude versus the small time t, one panel per z-slice.
# NOTE(review): `Data` is not assigned in any visible cell (the loading cell only defines
# Data_1 ... Data_5). The indexing below needs a 5D array, presumably one subject's
# (x, y, z, time, long_time) data. Confirm before running.

abs_data = np.abs(Data)

# Mean magnitude over the two in-plane spatial axes; the result is indexed below as
# avg_signal[z, t, param_index].
avg_signal = np.mean(abs_data, axis=(0, 1))

z_indices = range(0, 21)  # z indices 0 to 20

n_cols = 5  # panels per grid row
n_rows = -(-len(z_indices) // n_cols)  # ceiling division -> number of grid rows

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()  # 1D view so panels can be paired with z indices directly

# Fixed index into the last axis of avg_signal (presumably the long_time axis)
param_index = 0

for ax, z in zip(axes, z_indices):
    time_course = avg_signal[z, :, param_index]
    ax.plot(time_course, label=f"z Index {z}")
    ax.set_title(f"z Index {z}")
    ax.set_xlabel("Time Index (t)")
    ax.set_ylabel("Avg Signal Magnitude")
    ax.grid(True)
    ax.legend(loc="upper right", fontsize="small")

# Blank out the grid cells that have no z-slice assigned
for unused_ax in axes[len(z_indices):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()


2. Next, I visualize estimates of the SNR for each slice as a function of the small time t. The SNR is estimated as average_signal / sigma, where sigma is the noise level estimated from the four 3x3 corner patches of the slice.

In [None]:
# Plot the estimated SNR versus the small time t, one panel per z-slice.
# NOTE(review): neither `compute_all_snr` nor `Data` is defined in any visible cell.
# The indexing below needs a 3D SNR array, presumably (z, time, long_time). Confirm.
SNR = compute_all_snr(Data)

z_indices = range(0, 21)  # z indices 0 to 20

n_cols = 5  # panels per grid row
n_rows = -(-len(z_indices) // n_cols)  # ceiling division -> number of grid rows

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()  # 1D view so panels can be paired with z indices directly

# Fixed index into the last axis of SNR (presumably the long_time axis)
param_index = 0

for ax, z in zip(axes, z_indices):
    snr_time_course = SNR[z, :, param_index]
    ax.plot(snr_time_course, label=f"z Index {z}", color="blue")
    ax.set_title(f"SNR for z Index {z}")
    ax.set_xlabel("Time Index (t)")
    ax.set_ylabel("SNR")
    ax.grid(True)
    ax.legend(loc="upper right", fontsize="small")

# Blank out the grid cells that have no z-slice assigned
for unused_ax in axes[len(z_indices):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()


Next, I look at the spectral representation of the data to gain further insight.

In [None]:
# Show |spectral_data| of every z-slice at one fixed spectral index and long_time index.
# NOTE(review): `spectral_data` is not assigned in any visible cell. It is presumably the
# fftshift-ed FFT of `Data` along the time axis, as computed for Cleaned_Spectral_Data
# in the data-cleaning cell. Confirm.
spectral_index = 50  # Fixed spectral index
long_time_index = 5  # Fixed long_time index

z_indices = range(0, 21)  # z indices 0 to 20

n_cols = 5  # panels per grid row
n_rows = -(-len(z_indices) // n_cols)  # ceiling division -> number of grid rows

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()  # 1D view so panels can be paired with z indices directly

for ax, z in zip(axes, z_indices):
    absolute_slice = np.abs(spectral_data[:, :, z, spectral_index, long_time_index])
    im = ax.imshow(absolute_slice, cmap='viridis', origin='lower')
    ax.set_title(f"z Index {z}")
    ax.axis("off")  # image only, no ticks
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Blank out the grid cells that have no z-slice assigned
for unused_ax in axes[len(z_indices):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()

3. A fixed 2D slice at a good z position (z = 10), shown for varying spectral indices, to check which spectral indices are dominated by noise.
Conclusion: only spectral indices 44-52 lead to meaningful 2D images. The explanation for this behavior is given by the spectrum plot two cells below.

In [None]:
# Plot one fixed z-slice at a fixed long_time index for a sweep of spectral indices,
# to see which spectral indices carry signal and which are dominated by noise.
spectral_range = range(30, 70)  # Spectral indices 30 to 69 (range end is exclusive)

# Specify the spatial and long_time indices
z_index = 10  # Fixed z-slice
long_time_index = 7  # Fixed long_time index

# Determine grid size for plotting
n_cols = 5  # Number of columns in the grid
n_rows = int(np.ceil(len(spectral_range) / n_cols))  # 40 panels / 5 columns -> 8 rows

# Create a figure
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()  # Flatten for easier indexing

# Loop through the spectral indices and plot each slice
for i, spectral_index in enumerate(spectral_range):
    # Extract the 2D slice for the current spectral index
    slice_data = spectral_data[:, :, z_index, spectral_index, long_time_index]
    absolute_slice = np.abs(slice_data)

    # Plot the slice
    ax = axes[i]
    im = ax.imshow(absolute_slice, cmap='viridis', origin='lower')
    ax.set_title(f"Spectral Index {spectral_index}")
    ax.axis("off")  # Remove axis ticks for clarity
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Turn off unused subplots (none when the range fills the grid exactly)
for j in range(len(spectral_range), len(axes)):
    axes[j].axis("off")

# Adjust layout
plt.tight_layout()
plt.show()

Why do some images look like noise? To understand this, it's useful to visualize the spectrum; an example is given below.
There is only a strong signal close to the big peak of the spectrum, which probably corresponds to deuterium in water. 
For other frequencies the deuterium signal is so low that noise dominates. 

In [None]:
# Overlay the magnitude spectrum of a single voxel for a few selected long_time indices.
# Specify spatial indices
x, y, z = 10, 10, 10  # Spatial indices

# long_time indices whose spectra are overlaid. This is an explicit list of two values,
# not a range: only indices 0 and 7 are plotted.
long_time_indices = [0, 7]

# Create the plot
plt.figure(figsize=(10, 6))

# One curve per selected long_time index
for t in long_time_indices:
    spectrum = np.abs(spectral_data[x, y, z, :, t])  # Magnitude spectrum at long_time index t
    plt.plot(range(len(spectrum)), spectrum, marker='o', label=f"Long-time Index {t}")

# Add labels, title, and legend
plt.title(
    f"Spectra at Spatial Index ({x}, {y}, {z}) for long_time indices "
    f"{', '.join(map(str, long_time_indices))}"
)
plt.xlabel("Spectral Index")
plt.ylabel("Signal Intensity")
plt.grid(True)
plt.legend()  # Distinguish the long_time indices
plt.show()

Conclusion: As a starting point, I only keep images corresponding to spectral indices 44-52 and z indices 4-15.

**Data cleaning:** As shown above, for small-time indices t > 20 there is basically only noise. Here I check what happens if I set all samples with t > 20 to 0, to see whether the spectral plot becomes less noisy.

In [None]:
# Zero every sample with small-time index t >= t_cutoff (i.e. t > 20), which the
# plots above suggest contains essentially only noise. Then recompute the spectrum
# to see whether it becomes less noisy.
t_cutoff = 21
# Copy first: a plain assignment would only alias `Data`. The in-place zeroing
# below would then silently overwrite `Data` itself for every later (or re-run) cell.
Cleaned_Data = Data.copy()
Cleaned_Data[:, :, :, t_cutoff:, :] = 0

# FFT along the small-time axis (axis -2), then shift the zero frequency to the centre,
# giving spectral data that can be indexed by spectral index instead of time.
Cleaned_Spectral_Data = np.fft.fftshift(np.fft.fft(Cleaned_Data, axis=-2), axes=-2)

x, y, z = 10, 10, 10  # Spatial indices

# long_time indices whose spectra are overlaid. This is an explicit list of two values,
# not a range: only indices 0 and 7 are plotted.
long_time_indices = [0, 7]

# Create the plot
plt.figure(figsize=(10, 6))

# One curve per selected long_time index
for t in long_time_indices:
    spectrum = np.abs(Cleaned_Spectral_Data[x, y, z, :, t])  # Magnitude spectrum at long_time index t
    plt.plot(range(len(spectrum)), spectrum, marker='o', label=f"Long-time Index {t}")

# Add labels, title, and legend
plt.title(
    f"Cleaned spectra at Spatial Index ({x}, {y}, {z}) for long_time indices "
    f"{', '.join(map(str, long_time_indices))}"
)
plt.xlabel("Spectral Index")
plt.ylabel("Signal Intensity")
plt.grid(True)
plt.legend()  # Distinguish the long_time indices
plt.show()