In [1]:
from Utils2.Dataset import Unrolled_Dataset
from Utils2.Unrolled_2iteration import Physics
from Utils2.Unrolled_2iteration import SamplingFunction
from Utils2.Unrolled_2iteration import UnrolledReconstructor
from Utils2.Model import TinyUNET

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from monai.metrics.regression import SSIMMetric
import matplotlib.pyplot as plt
import numpy as np
import os

### 1. Dataset and DataLoader

In [2]:
# Root directory of the preprocessed multicoil dataset on the shared filesystem.
data_path = '/shared/BIOE486/SP25/users/jgarca2/Dataset/multicoil_val_prepro_valid_dataset'
dataset = Unrolled_Dataset(unroll_root_dir=data_path, transform=True)

# Shuffled loader: 16 samples per batch, 16 worker processes, pinned host memory.
train_data = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

# Pull only the first batch to sanity-check tensor shapes; zf/us/cs/gt stay
# bound at module level afterwards and are reused by the following cells.
# Shapes noted by the original author (not verified here):
#   ZF (B, 1, 320, 320) · US (B, 15, 640, 115) · CS (B, 15, 320, 320) · GT (B, 1, 320, 320)
for i, (zf, us, cs, gt) in enumerate(train_data):
    print(f"Batch {i}:")
    for label, tensor in (("ZF", zf), ("US", us), ("CS", cs), ("GT", gt)):
        print(f"  {label} shape: {tensor.shape}")
    break

Found files:
  Zero-filled:         6661
  Coil sensitivity:    6661
  Undersampled:        6661
  Ground truth (Cs):   6661


KeyboardInterrupt: 

### 2. Physics

In [None]:
# Run the physics step once on the batch loaded earlier (zf, us, cs) and
# report the shapes of each intermediate result.
sampler = SamplingFunction()
physics = Physics(alpha=0.1, sampler=sampler)

W_e = physics._compute_W_e(us, cs)
S = physics._compute_S(zf, cs)

# Stored as `physics_input` rather than `input`: binding `input` at module
# level would shadow the builtin `input()` for the rest of the kernel session.
physics_input = physics._final_sum(S, W_e)
print(W_e.shape, S.shape, physics_input.shape)

### 3. Deep learning model (UNETxComplex)

In [None]:
# Write the TinyUNET computation graph to TensorBoard.
log_dir = "/shared/BIOE486/SP25/users/jgarca2/Experiments/Graph"
writer = SummaryWriter(log_dir=log_dir)

# Random complex-valued dummy input used only to trace the graph.
real = torch.randn(1, 1, 320, 320)
imag = torch.randn(1, 1, 320, 320)
model_input = torch.complex(real, imag)
model = TinyUNET()

# Log the graph; always close the writer so the event file is flushed and
# released even when tracing the model raises.
try:
    writer.add_graph(model, model_input)
finally:
    writer.close()

### 4. Unroll physics and DL model

In [None]:
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# Send the two networks and the batch tensors to the device
model1, model2 = TinyUNET().to(device), TinyUNET().to(device)
zf, us, cs = zf.to(device), us.to(device), cs.to(device)

# Initialize the unrolled reconstructor
reconstructor = UnrolledReconstructor(model1, model2).to(device)

# Smoke-test forward pass. Nothing is back-propagated from this output, so
# autograd bookkeeping is disabled to avoid holding activations in memory.
with torch.no_grad():
    output = reconstructor(zf, us, cs)

# presumably [B, 1, 320, 320] with B equal to the DataLoader batch size — TODO confirm
print(f"Output shape: {output.shape}")


### 6. Training loop

In [None]:
class LogMSELoss(nn.Module):
    """Mean squared error between the log-magnitudes of two tensors.

    Both inputs are mapped through ``log(|x| + eps)`` before the squared
    difference is averaged over every element.

    Args:
        eps: offset added to the magnitude so that zeros do not produce
            ``log(0)``.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar log-magnitude MSE of ``pred`` against ``target``."""
        log_pred = (pred.abs() + self.eps).log()
        log_target = (target.abs() + self.eps).log()
        residual = log_pred - log_target
        return residual.pow(2).mean()

# Loss & Metrics
criterion = LogMSELoss()
ssim_metric = SSIMMetric(spatial_dims=2)

# Fresh networks for training.
model1, model2 = TinyUNET().to(device), TinyUNET().to(device)
reconstructor = UnrolledReconstructor(model1, model2).to(device)

# The optimizer has to be created AFTER the reconstructor and over ITS
# parameters; otherwise optimizer.step() updates an unrelated network and the
# reconstructor never learns.
# Assumes UnrolledReconstructor registers model1/model2 as submodules so their
# weights show up in .parameters() — TODO confirm.
optimizer = torch.optim.Adam(reconstructor.parameters(), lr=1e-3)

for epoch in range(1, 10):  # epochs 1..9
    reconstructor.train()
    # Drop per-batch results the metric object kept from the previous epoch.
    ssim_metric.reset()
    train_loss = 0
    ssim_total = 0
    ssim_count = 0

    for i, (zf, us, cs, gt) in enumerate(train_data):
        print(f"Epoch {epoch}, Batch {i+1}", end='\r')

        # Move to device
        zf, us, cs, gt = zf.to(device), us.to(device), cs.to(device), gt.to(device)

        # --- Scale inputs by the per-sample peak magnitude of the ground truth ---
        # NOTE(review): gt itself is left unscaled, so the network has to learn
        # the scale offset between scaled inputs and raw targets — confirm intended.
        gt_max = gt.abs().amax(dim=(-1, -2), keepdim=True)
        zf = zf / (gt_max + 1e-8)
        cs = cs / (gt_max + 1e-8)

        # Forward
        pred = reconstructor(zf, us, cs)

        # Loss on log-magnitude
        loss = criterion(pred, gt)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample.
        train_loss += loss.item() * zf.size(0)

        # --- SSIM (monitoring only, so keep it out of the autograd graph) ---
        with torch.no_grad():
            # Convert to magnitude
            pred_mag = torch.abs(pred.detach())
            gt_mag = torch.abs(gt)

            # Normalize to [0, 1] (optional but helps SSIM stability)
            pred_mag = pred_mag / (pred_mag.max() + 1e-8)
            gt_mag = gt_mag / (gt_mag.max() + 1e-8)

            ssim_score = ssim_metric(pred_mag, gt_mag)
        ssim_total += ssim_score.mean().item()
        ssim_count += 1

    epoch_train_loss = train_loss / len(train_data.dataset)
    epoch_ssim = ssim_total / ssim_count
    print(f"Epoch {epoch:02d} | Loss: {epoch_train_loss:.4f} | SSIM: {epoch_ssim:.4f}")

### 7. Final reconstruction

In [None]:
def show_reconstruction_logscaled(zf, recon, gt, idx=0, eps=1e-5):
    """
    Plot one sample's zero-filled input, reconstruction and ground truth
    side by side as log-magnitude images on a common gray scale with a
    single shared colorbar.

    Parameters:
        zf, recon, gt: tensors indexed as [batch, channel, ...]; element
            [idx, 0] of each is what gets displayed
        idx: position of the sample within the batch
        eps: offset added to the magnitude before taking the log
    """
    def _log_magnitude(batch):
        # Pick sample `idx`, channel 0, convert to NumPy, then log(|x| + eps).
        return np.log(np.abs(batch[idx, 0].cpu().detach().numpy()) + eps)

    panels = [
        ('Zerofill', _log_magnitude(zf)),
        ('Reconstruction', _log_magnitude(recon)),
        ('Ground Truth', _log_magnitude(gt)),
    ]

    # One intensity window across all three panels so they are directly comparable.
    stacked = np.stack([img for _, img in panels], axis=0)
    lo, hi = stacked.min(), stacked.max()

    fig, axes = plt.subplots(1, 3, figsize=(10, 10), constrained_layout=True)
    for ax, (title, img) in zip(axes, panels):
        im = ax.imshow(img, cmap='gray', vmin=lo, vmax=hi)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

    # Every panel shares vmin/vmax, so the last image handle describes all of them.
    fig.colorbar(im, ax=axes, fraction=0.015, pad=0.04)
    plt.show()

In [None]:
# Grab a single batch for qualitative inspection.
for i, (zf, us, cs, gt) in enumerate(train_data):
    print(f"Batch {i}:")
    print(f"  ZF shape: {zf.shape}")
    print(f"  US shape: {us.shape}")
    print(f"  CS shape: {cs.shape}")
    print(f"  GT shape: {gt.shape}")
    break

zf, us, cs, gt = zf.to(device), us.to(device), cs.to(device), gt.to(device)

# Feed the network the same input scaling the training loop applies (zf and cs
# divided by the per-sample peak magnitude of gt). The scaled copies get their
# own names so the unscaled zf stays available for plotting.
gt_max = gt.abs().amax(dim=(-1, -2), keepdim=True)
zf_scaled = zf / (gt_max + 1e-8)
cs_scaled = cs / (gt_max + 1e-8)

# Inference only: switch to eval mode and skip autograd bookkeeping.
reconstructor.eval()
with torch.no_grad():
    output = reconstructor(zf_scaled, us, cs_scaled)

In [None]:
# Sample 31 only exists when the batch holds at least 32 items; the loader in
# this notebook uses batch_size=16, so clamp to the last available sample
# instead of raising IndexError.
sample_idx = min(31, output.shape[0] - 1)
show_reconstruction_logscaled(zf, output, gt, idx=sample_idx, eps=1e-4)