In [None]:
pip install numpy torch scikit-image h5py pywt matplotlib

In [1]:
import os
import glob
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import pywt
from numpy.fft import fft2, ifft2, fftshift, ifftshift
from skimage.transform import resize
import time
import gc

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Mini U-Net CNN Architecture ---
class MiniUNet(nn.Module):
    def __init__(self, img_size=256):
        super(MiniUNet, self).__init__()
        # Encoder
        self.enc1 = nn.Conv2d(2, 32, 3, padding=1)  # Input: real+imag
        self.enc2 = nn.Conv2d(32, 64, 3, padding=1)
        # Bottleneck
        self.bottleneck = nn.Conv2d(64, 128, 3, padding=1)
        # Decoder
        self.dec2 = nn.Conv2d(192, 64, 3, padding=1)  # Skip connection
        self.dec1 = nn.Conv2d(96, 32, 3, padding=1)
        self.out = nn.Conv2d(32, 1, 3, padding=1)  # Output: real image
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        # Encoder
        e1 = self.relu(self.enc1(x))  # [B, 32, img_size, img_size]
        e2 = self.relu(self.enc2(self.pool(e1)))  # [B, 64, img_size/2, img_size/2]
        # Bottleneck
        b = self.relu(self.bottleneck(self.pool(e2)))  # [B, 128, img_size/4, img_size/4]
        # Decoder with skip connections
        d2 = self.relu(self.dec2(torch.cat([self.upsample(b), e2], dim=1)))  # [B, 64, img_size/2]
        d1 = self.relu(self.dec1(torch.cat([self.upsample(d2), e1], dim=1)))  # [B, 32, img_size]
        out = self.out(d1)  # [B, 1, img_size, img_size]
        return out

# --- Dataset Loading ---
def load_fastmri_slices(base_path, num_subjects=10, slices_per_subject=5):
    h5_files = glob.glob(os.path.join(base_path, '*.h5'))
    if not h5_files:
        print(f"ERROR: No .h5 files found in {base_path}")
        return [], []

    num_subjects = min(num_subjects, len(h5_files))
    h5_files = h5_files[:num_subjects]
    images = []
    subjects_slices = []

    for h5_path in h5_files:
        try:
            with h5py.File(h5_path, 'r') as f:
                kspace = f['kspace'][()]  # Shape: (slices, coils, height, width)
                num_slices = kspace.shape[0]
                central_idx = num_slices // 2
                slice_indices = range(central_idx - slices_per_subject//2, central_idx + slices_per_subject//2 + 1)
                subject_id = os.path.basename(h5_path).split('.')[0]

                for slice_idx in slice_indices:
                    if slice_idx < 0 or slice_idx >= num_slices:
                        continue
                    kspace_slice = kspace[slice_idx]  # Shape: (coils, height, width)
                    img_coils = ifft2(ifftshift(kspace_slice, axes=(1, 2)), axes=(1, 2))
                    img_rss = np.sqrt(np.sum(np.abs(img_coils)**2, axis=0))
                    img_rss = resize(img_rss, (256, 256), anti_aliasing=True).astype(np.float32)
                    if np.max(img_rss) > np.min(img_rss):
                        img_rss = (img_rss - np.min(img_rss)) / (np.max(img_rss) - np.min(img_rss))
                    else:
                        img_rss = np.zeros_like(img_rss)
                    images.append(img_rss)
                    subjects_slices.append((subject_id, slice_idx))
        except Exception as e:
            print(f"ERROR loading {h5_path}: {e}")
            continue

    return np.array(images), subjects_slices

# --- Mask and Reconstruction Functions ---
def create_variable_density_mask(shape, acceleration_factor, center_fraction=0.08, poly_degree=2, seed=None):
    if seed is not None:
        np.random.seed(seed)
    rows, cols = shape
    center_x, center_y = cols // 2, rows // 2
    x_coords = np.abs(np.arange(cols) - center_x)
    y_coords = np.abs(np.arange(rows) - center_y)
    dist_x, dist_y = np.meshgrid(x_coords, y_coords)
    norm_dist_x = dist_x / (np.max(dist_x) if np.max(dist_x) > 0 else 1)
    norm_dist_y = dist_y / (np.max(dist_y) if np.max(dist_y) > 0 else 1)
    pdf = (1 - norm_dist_x**poly_degree) * (1 - norm_dist_y**poly_degree)
    pdf = np.clip(pdf, 0, 1)
    target_samples = int(np.prod(shape) / acceleration_factor)
    flat_pdf = pdf.flatten()
    sorted_indices = np.argsort(-flat_pdf)
    mask = np.zeros(shape, dtype=bool).flatten()
    mask[sorted_indices[:target_samples]] = True
    mask = mask.reshape(shape)
    center_rows_abs = int(shape[0] * center_fraction)
    center_cols_abs = int(shape[1] * center_fraction)
    r_start, r_end = shape[0]//2 - center_rows_abs//2, shape[0]//2 + center_rows_abs//2
    c_start, c_end = shape[1]//2 - center_cols_abs//2, shape[1]//2 + center_cols_abs//2
    mask[r_start:r_end, c_start:c_end] = True
    actual_accel = np.prod(shape) / np.sum(mask)
    print(f"Variable Density Mask: Target R={acceleration_factor}, Actual R={actual_accel:.2f}")
    return mask.astype(np.float32)

def apply_mask(img, mask):
    kspace = fftshift(fft2(img))
    undersampled_kspace = kspace * mask
    return undersampled_kspace, kspace

def zero_filled_reconstruction(undersampled_kspace):
    img_zf = np.abs(ifft2(ifftshift(undersampled_kspace)))
    return np.clip(img_zf, 0, 1)

def prepare_cnn_input(undersampled_kspace):
    zf_img = ifft2(ifftshift(undersampled_kspace))
    zf_real = np.real(zf_img)
    zf_imag = np.imag(zf_img)
    input_tensor = np.stack([zf_real, zf_imag], axis=0)  # [2, 256, 256]
    return torch.tensor(input_tensor, dtype=torch.float32).unsqueeze(0)  # [1, 2, 256, 256]

# --- ISTA Functions (from provided script) ---
def soft_threshold(x, threshold):
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

def wavelet_forward(image, wavelet='db4', level=3):
    coeffs = pywt.wavedec2(image, wavelet=wavelet, level=level)
    arr, coeff_slices = pywt.coeffs_to_array(coeffs)
    return arr, coeff_slices

def wavelet_inverse(arr, coeff_slices, wavelet='db4'):
    coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices, output_format='wavedec2')
    return pywt.waverec2(coeffs_from_arr, wavelet=wavelet)

def ista_wavelet_cs(k_space_undersampled, mask, initial_image, n_iters, lambda_val, ground_truth_for_psnr):
    x_recon = initial_image.copy().astype(np.complex128)
    k_space_undersampled = k_space_undersampled.astype(np.complex128)
    step_size = 1.0
    for i in range(n_iters):
        current_k_space = fftshift(fft2(x_recon))
        k_space_error = (current_k_space * mask) - k_space_undersampled
        grad_data_term = ifft2(ifftshift(k_space_error * mask))
        x_intermediate = x_recon - step_size * grad_data_term
        x_intermediate_real = np.real(x_intermediate)
        coeffs_arr, coeff_slices = wavelet_forward(x_intermediate_real, wavelet='db4', level=3)
        threshold = lambda_val * step_size
        coeffs_list_form = pywt.wavedec2(x_intermediate_real, wavelet='db4', level=3)
        approx_coeffs_size = coeffs_list_form[0].size
        coeffs_arr_thresh = coeffs_arr.copy()
        coeffs_arr_thresh[approx_coeffs_size:] = soft_threshold(coeffs_arr[approx_coeffs_size:], threshold)
        x_reconstructed_real = wavelet_inverse(coeffs_arr_thresh, coeff_slices, wavelet='db4')
        x_recon = x_reconstructed_real.astype(np.complex128)
    return np.clip(np.real(x_recon), 0, 1)

# --- Dataset Class ---
class FastMRIDataset(torch.utils.data.Dataset):
    def __init__(self, images, mask):
        self.images = images  # [num_images, 256, 256]
        self.mask = mask  # [256, 256]
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = self.images[idx]
        undersampled_kspace, _ = apply_mask(img, self.mask)
        input_tensor = prepare_cnn_input(undersampled_kspace)
        target = torch.tensor(img[None], dtype=torch.float32)  # [1, 256, 256]
        return input_tensor.squeeze(0), target  # [2, 256, 256], [1, 256, 256]

# --- Training Function ---
def train_cnn(model, dataset, epochs=50, batch_size=4, lr=0.001):
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)  # [B, 2, 256, 256], [B, 1, 256, 256]
            optimizer.zero_grad()
            outputs = model(inputs)  # [B, 1, 256, 256]
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / len(dataloader):.4f}')
    return model

# --- Evaluation Function ---
def evaluate_reconstruction(img_recon, img_ref):
    img_recon = np.clip(img_recon, 0, 1)
    img_ref = np.clip(img_ref, 0, 1)
    psnr_val = psnr(img_ref, img_recon, data_range=1.0)
    ssim_val = ssim(img_ref, img_recon, data_range=1.0, channel_axis=None)
    return psnr_val, ssim_val

# --- Main Function ---
def main():
    # Define parameters
    base_path = '/kaggle/input/fastmri-brain-multicoil'  # Replace with local path if not on Kaggle
    acceleration_factor = 4
    optimal_lambda = 0.0001
    optimal_iters = 10

    # Load dataset
    images, subjects_slices = load_fastmri_slices(base_path, num_subjects=10, slices_per_subject=5)
    if len(images) == 0:
        raise RuntimeError("No valid images loaded. Check dataset path.")

    # Create mask
    mask = create_variable_density_mask((256, 256), acceleration_factor, seed=0)

    # Split dataset (8 subjects for training, 2 for testing)
    train_images = images[:40]  # 8 subjects * 5 slices
    test_images = images[40:]   # 2 subjects * 5 slices
    test_subjects_slices = subjects_slices[40:]

    # Create training dataset
    train_dataset = FastMRIDataset(train_images, mask)

    # Initialize and train CNN
    model = MiniUNet(img_size=256)
    start_time = time.time()
    trained_model = train_cnn(model, train_dataset, epochs=50, batch_size=4)
    cnn_train_time = time.time() - start_time

    # Evaluate on test slices
    results = {'ZF': [], 'ISTA': [], 'CNN': []}
    for idx, (test_img, (subject_id, slice_idx)) in enumerate(zip(test_images, test_subjects_slices)):
        print(f'\nEvaluating {subject_id}, Slice {slice_idx}')
        undersampled_kspace, _ = apply_mask(test_img, mask)

        # Zero-Filled
        start_time = time.time()
        img_zf = zero_filled_reconstruction(undersampled_kspace)
        zf_time = time.time() - start_time
        zf_psnr, zf_ssim = evaluate_reconstruction(img_zf, test_img)
        results['ZF'].append((zf_psnr, zf_ssim, zf_time))
        print(f'Zero-Filled: PSNR={zf_psnr:.3f}, SSIM={zf_ssim:.4f}, Time={zf_time:.2f}s')

        # ISTA
        start_time = time.time()
        img_ista = ista_wavelet_cs(undersampled_kspace, mask, img_zf.copy(), optimal_iters, optimal_lambda, test_img)
        ista_time = time.time() - start_time
        ista_psnr, ista_ssim = evaluate_reconstruction(img_ista, test_img)
        results['ISTA'].append((ista_psnr, ista_ssim, ista_time))
        print(f'ISTA: PSNR={ista_psnr:.3f}, SSIM={ista_ssim:.4f}, Time={ista_time:.2f}s')

        # CNN
        start_time = time.time()
        input_tensor = prepare_cnn_input(undersampled_kspace)
        with torch.no_grad():
            output = trained_model(input_tensor.to(device))
            img_cnn = output.squeeze().cpu().numpy()  # [256, 256]
        cnn_inf_time = time.time() - start_time
        cnn_psnr, cnn_ssim = evaluate_reconstruction(img_cnn, test_img)
        results['CNN'].append((cnn_psnr, cnn_ssim, cnn_inf_time))
        print(f'CNN: PSNR={cnn_psnr:.3f}, SSIM={cnn_ssim:.4f}, Inference Time={cnn_inf_time:.2f}s')

        # Visualize
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 4, 1)
        plt.imshow(test_img, cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')
        plt.subplot(1, 4, 2)
        plt.imshow(img_zf, cmap='gray')
        plt.title(f'ZF\nPSNR: {zf_psnr:.2f}\nSSIM: {zf_ssim:.4f}')
        plt.axis('off')
        plt.subplot(1, 4, 3)
        plt.imshow(img_ista, cmap='gray')
        plt.title(f'ISTA\nPSNR: {ista_psnr:.2f}\nSSIM: {ista_ssim:.4f}')
        plt.axis('off')
        plt.subplot(1, 4, 4)
        plt.imshow(img_cnn, cmap='gray')
        plt.title(f'CNN\nPSNR: {cnn_psnr:.2f}\nSSIM: {cnn_ssim:.4f}')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(f'result_{subject_id}_slice_{slice_idx}.png', dpi=300)
        plt.close()
        gc.collect()

    # Summarize results
    print("\n--- Summary of Results (R=4) ---")
    print(f"{'Method':<10} | {'Mean PSNR':<12} | {'Mean SSIM':<12} | {'Mean Time (s)':<15}")
    print("-" * 50)
    for method in results:
        psnr_vals, ssim_vals, times = zip(*results[method])
        mean_psnr = np.mean(psnr_vals)
        mean_ssim = np.mean(ssim_vals)
        mean_time = np.mean(times)
        print(f"{method:<10} | {mean_psnr:.2f} | {mean_ssim:.4f} | {mean_time:.2f}")
    print(f'\nCNN Total Time (Train+Inf): ~{cnn_train_time + np.mean([r[2] for r in results["CNN"]]):.2f}s')

if __name__ == '__main__':
    main()

Using device: cpu
ERROR loading /kaggle/input/fastmri-brain-multicoil/file_brain_AXT2_210_2100179.h5: Unable to synchronously open file (truncated file: eof = 201064448, sblock->base_addr = 0, stored_eof = 778585656)
Variable Density Mask: Target R=4, Actual R=4.00
Epoch 0, Loss: 0.0131
Epoch 10, Loss: 0.0001
Epoch 20, Loss: 0.0001
Epoch 30, Loss: 0.0000
Epoch 40, Loss: 0.0000

Evaluating file_brain_AXT2_207_2070286, Slice 6
Zero-Filled: PSNR=44.318, SSIM=0.9899, Time=0.00s
ISTA: PSNR=44.327, SSIM=0.9899, Time=0.13s
CNN: PSNR=43.906, SSIM=0.9892, Inference Time=0.18s

Evaluating file_brain_AXT2_207_2070286, Slice 7
Zero-Filled: PSNR=44.738, SSIM=0.9910, Time=0.00s
ISTA: PSNR=44.740, SSIM=0.9910, Time=0.10s
CNN: PSNR=44.214, SSIM=0.9902, Inference Time=0.12s

Evaluating file_brain_AXT2_207_2070286, Slice 8
Zero-Filled: PSNR=45.511, SSIM=0.9927, Time=0.00s
ISTA: PSNR=45.511, SSIM=0.9927, Time=0.11s
CNN: PSNR=44.931, SSIM=0.9916, Inference Time=0.13s

Evaluating file_brain_AXT2_207_207028