# Environment Setup

In [26]:
import os
import numpy as np
import random
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
from pathlib import Path
from skimage.metrics import structural_similarity as ssim
from fastmri.data import transforms as T
from fastmri.data.subsample import RandomMaskFunc

# Pin every random number generator used in this notebook to one seed so that
# repeated runs draw the same numbers.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Choose the compute device once; the CUDA-only settings hang off that choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels

print(f"Using device: {device}")


Using device: cuda


# Dataset Class Definition

In [27]:
class ProcessedFastMRIDataset(Dataset):
    """Dataset of (input, target) image pairs for FastMRI reconstruction.

    Two sources are supported:

    * ``use_processed=True``: every ``*.pt`` file in ``<data_dir>/<mode>`` whose
      file name does not contain ``"metadata"`` is expected to hold a dict with
      ``'inputs'`` and ``'targets'`` entries. Each position inside ``'inputs'``
      becomes one example.
    * ``use_processed=False``: every path in ``file_list`` is opened with h5py
      and its middle k-space slice becomes one example; the image pair is
      computed on the fly in ``__getitem__``.

    ``self.examples`` holds ``(batch_file_index, sample_index)`` tuples in the
    first mode and ``(file_path, slice_index)`` tuples in the second.
    """

    def __init__(self, data_dir=None, file_list=None, mode='train', mask_func=None, use_processed=True):
        """
        Args:
            data_dir: Root directory containing one sub-directory per ``mode``.
                Required when ``use_processed`` is True.
            file_list: Iterable of raw HDF5 paths (str or Path). Required when
                ``use_processed`` is False.
            mode (str): Sub-directory of ``data_dir`` to read (e.g. 'train', 'val').
            mask_func: Optional mask function forwarded to ``T.apply_mask`` in
                raw mode; when falsy the k-space is used unmasked.
            use_processed (bool): Selects between the two sources described on
                the class.

        Raises:
            ValueError: ``data_dir`` is missing in processed mode, ``file_list``
                is empty in raw mode, or no example could be indexed.
            FileNotFoundError: The mode directory does not exist or contains no
                usable ``.pt`` file.
        """
        self.mode = mode
        self.use_processed = use_processed
        self.examples = []

        # Single-entry cache for processed mode: the most recently loaded batch
        # file and its index (see _load_batch).
        self._cached_batch_idx = None
        self._cached_batch = None

        if use_processed:
            if data_dir is None:
                # os.path.join(None, mode) would otherwise fail with an opaque TypeError.
                raise ValueError("data_dir must be provided when use_processed is True.")
            self.data_dir = os.path.join(data_dir, mode)
            self._index_processed_batches()
        else:
            self.file_list = file_list
            self.mask_func = mask_func
            self._index_raw_files()

    def _index_processed_batches(self):
        """Discover batch files in ``self.data_dir`` and index every sample they contain."""
        try:
            all_files_in_dir = os.listdir(self.data_dir)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}") from err

        # Sorted so that batch indices are stable from run to run.
        all_pt_files = sorted(os.path.join(self.data_dir, f) for f in all_files_in_dir if f.endswith('.pt'))

        # Files with "metadata" in their name are not sample batches.
        self.batch_files = []
        for f_path_str in all_pt_files:
            if "metadata" not in Path(f_path_str).name:
                self.batch_files.append(f_path_str)
            else:
                print(f"Filtering out metadata file: {f_path_str}")

        if not self.batch_files:
            raise FileNotFoundError(f"No valid .pt files (after filtering 'metadata' files) found in {self.data_dir}")

        print(f"Found {len(self.batch_files)} .pt files to process in {self.data_dir} after filtering.")

        for i, batch_file_path in enumerate(self.batch_files):
            print(f"Processing file: {batch_file_path}")
            try:
                batch = torch.load(batch_file_path, map_location='cpu')  # Load on CPU

                if not isinstance(batch, dict):
                    print(f"  Warning: Skipped {batch_file_path}. Loaded object is not a dictionary (type: {type(batch)}).")
                    continue

                if 'inputs' not in batch:
                    print(f"  Warning: Skipped {batch_file_path}. Missing 'inputs' key. Keys present: {list(batch.keys())}.")
                    continue

                if 'targets' not in batch:
                    print(f"  Warning: Skipped {batch_file_path}. Missing 'targets' key. Keys present: {list(batch.keys())}.")
                    continue

                num_samples = len(batch['inputs'])
                self.examples.extend((i, j) for j in range(num_samples))

            except Exception as e:
                # Best effort: one unreadable file must not abort indexing of the rest.
                print(f"  Error loading or processing file {batch_file_path}: {e}. Skipping.")

        if not self.examples:
            raise ValueError(f"No valid examples could be loaded from {self.data_dir}. Check warnings above.")

        print(f"Successfully loaded a total of {len(self.examples)} examples.")

    def _index_raw_files(self):
        """Register the middle k-space slice of every readable file in ``self.file_list``."""
        if not self.file_list:
            raise ValueError("File list is empty when use_processed is False.")

        print(f"Processing {len(self.file_list)} raw files...")
        for fpath_obj in self.file_list:
            fpath = str(fpath_obj)  # h5py is given a plain string path
            try:
                with h5py.File(fpath, 'r') as hf:
                    middle_slice = hf['kspace'].shape[0] // 2
                self.examples.append((fpath, middle_slice))
            except Exception as e:
                print(f"Error processing raw file {fpath}: {e}")

        if not self.examples:
            raise ValueError("No examples could be prepared from raw files.")

    def _load_batch(self, batch_idx):
        """Return the content of batch file ``batch_idx``, reusing the last loaded file.

        Without this, every sample access re-reads its whole batch file from
        disk. Only one batch is retained at a time. The returned object is
        shared between calls, so callers must not modify it in place.
        """
        if getattr(self, '_cached_batch_idx', None) != batch_idx:
            # Assign the data before the index so a failed load leaves the cache consistent.
            self._cached_batch = torch.load(self.batch_files[batch_idx], map_location='cpu')
            self._cached_batch_idx = batch_idx
        return self._cached_batch

    def _load_raw_example(self, fpath, slice_idx):
        """Build one (input, target) pair from slice ``slice_idx`` of raw file ``fpath``.

        Raises:
            ValueError: The file has no ``'reconstruction_rss'`` dataset.
        """
        with h5py.File(fpath, 'r') as hf:
            # Fail before doing any FFT work when there is no target to pair with.
            if 'reconstruction_rss' not in hf:
                raise ValueError(f"Target 'reconstruction_rss' not found in {fpath}")

            kspace = hf['kspace'][slice_idx]
            target_rss = hf['reconstruction_rss'][slice_idx]
            # Fallback for max_value if not in attrs, ensure this value is appropriate
            max_value = hf.attrs.get('max', 0.00085)

        kspace_tensor = T.to_tensor(kspace)
        if self.mask_func:
            masked_kspace, _, _ = T.apply_mask(kspace_tensor, self.mask_func)
        else:
            masked_kspace = kspace_tensor

        # k-space -> image magnitude; input and target are divided by the same max_value.
        image_abs = T.complex_abs(T.ifft2c(masked_kspace))
        image_normalized = image_abs / max_value
        target_normalized = T.to_tensor(target_rss) / max_value

        crop_shape = (320, 320)
        inputs = T.center_crop(image_normalized, crop_shape)
        targets = T.center_crop(target_normalized, crop_shape)
        return inputs, targets

    def __len__(self):
        """Number of indexed examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair for example ``idx``.

        No channel dimension is added here; the value is returned exactly as
        stored in the batch file (processed mode) or as produced by the center
        crop (raw mode).
        """
        if self.use_processed:
            batch_idx, sample_idx = self.examples[idx]
            batch_data = self._load_batch(batch_idx)
            inputs = batch_data['inputs'][sample_idx]
            targets = batch_data['targets'][sample_idx]
        else:
            fpath, slice_idx = self.examples[idx]
            inputs, targets = self._load_raw_example(fpath, slice_idx)

        return inputs, targets


# UNet Model Definition

In [28]:
class ConvBlock(nn.Module):
    """
    Two identical stages applied back to back, each being
    3x3 convolution (no bias) -> instance norm -> LeakyReLU(0.2) -> 2D dropout.
    Spatial size is preserved (padding=1).
    """
    def __init__(self, in_chans, out_chans, drop_prob=0.0):
        """
        Args:
            in_chans (int): Channel count of the incoming tensor.
            out_chans (int): Channel count produced by both stages.
            drop_prob (float): Probability used by both dropout layers.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        # The first stage maps in_chans -> out_chans, the second out_chans -> out_chans.
        # Building the stages in order keeps the Sequential indices (0..7) unchanged.
        stages = []
        for stage_in in (in_chans, out_chans):
            stages.extend([
                nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=1, bias=False),
                nn.InstanceNorm2d(out_chans),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Dropout2d(drop_prob),
            ])
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        return self.layers(x)


class UNet(nn.Module):
    """
    PyTorch implementation of a U-Net model for MRI reconstruction.
    This is the standard U-Net architecture used as a baseline for fastMRI.

    The encoder applies ``num_pool_layers`` ConvBlocks, each followed by 2x2 max
    pooling. The decoder mirrors it with 2x2 transposed convolutions, each
    concatenated with the matching encoder output. A final 1x1 convolution maps
    to ``out_chans``.
    """
    def __init__(self, in_chans=1, out_chans=1, chans=32, num_pool_layers=4, drop_prob=0.0):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
            chans (int): Number of output channels of the first convolution layer.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Downsampling path: the channel count doubles at every level after the first.
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2

        # Bottleneck
        self.conv = ConvBlock(ch, ch * 2, drop_prob)
        ch *= 2

        # Upsampling path: the transposed conv halves the channels, the skip
        # concatenation restores them, and the ConvBlock halves them again.
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers):
            self.up_transpose_conv.append(nn.ConvTranspose2d(ch, ch // 2, kernel_size=2, stride=2))
            self.up_conv.append(ConvBlock(ch, ch // 2, drop_prob))
            ch //= 2

        # Final output layer
        self.conv2 = nn.Conv2d(ch, out_chans, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, in_chans, height, width]. A 3-D
               tensor is treated as [batch_size, height, width] and gets a channel
               dimension of size 1 inserted.
        Returns:
            Output tensor of shape [batch_size, out_chans, height, width]
        """
        # Add channel dimension if not present
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # Encoder: remember every pre-pooling activation for the skip connections.
        skip_connections = []
        for layer in self.down_sample_layers:
            x = layer(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv(x)

        # Decoder: consume the skips from deepest to shallowest.
        for transpose_conv, conv_block in zip(self.up_transpose_conv, self.up_conv):
            skip = skip_connections.pop()
            x = transpose_conv(x)

            # Max pooling floors odd sizes, so after up-sampling x can be one
            # row/column smaller than the skip tensor. Pad on the right/bottom
            # so that torch.cat does not fail for heights or widths that are
            # not divisible by 2**num_pool_layers.
            pad_w = skip.shape[-1] - x.shape[-1]
            pad_h = skip.shape[-2] - x.shape[-2]
            if pad_w or pad_h:
                x = F.pad(x, [0, pad_w, 0, pad_h], mode="reflect")

            x = torch.cat([x, skip], dim=1)
            x = conv_block(x)

        return self.conv2(x)


# Loss Functions and Metrics

In [29]:
class SSIMLoss(nn.Module):
    """Loss equal to ``1 - mean SSIM`` computed with a uniform sliding window."""
    def __init__(self, win_size=7, k1=0.01, k2=0.03):
        super().__init__()
        self.win_size = win_size
        self.k1, self.k2 = k1, k2
        # Uniform averaging kernel; registered as a buffer so .to(device) moves it too.
        self.register_buffer('w', torch.ones(1, 1, win_size, win_size) / win_size**2)
        # Factor N / (N - 1) applied to the windowed (co)variances.
        self.cov_norm = win_size**2 / (win_size**2 - 1)

    def forward(self, x, y):
        """Return ``1 - mean(SSIM map)`` for ``x`` and ``y``.

        Both tensors are convolved with a (1, 1, win, win) kernel, so they must
        be accepted by ``F.conv2d`` with a single input channel.
        """
        data_range = 1.0  # Images are normalized to [0,1]
        c1 = (self.k1 * data_range)**2
        c2 = (self.k2 * data_range)**2

        def window_mean(t):
            # Local mean over every win_size x win_size patch (no padding).
            return F.conv2d(t, self.w)

        # First-order statistics
        mu_x = window_mean(x)
        mu_y = window_mean(y)

        # Second-order statistics: E[ab] - E[a]E[b], scaled by cov_norm
        var_x = self.cov_norm * (window_mean(x * x) - mu_x * mu_x)
        var_y = self.cov_norm * (window_mean(y * y) - mu_y * mu_y)
        cov_xy = self.cov_norm * (window_mean(x * y) - mu_x * mu_y)

        # Numerator and denominator factors of the SSIM index
        num_mean = 2 * mu_x * mu_y + c1
        num_cov = 2 * cov_xy + c2
        den_mean = mu_x**2 + mu_y**2 + c1
        den_var = var_x + var_y + c2

        ssim_map = (num_mean * num_cov) / (den_mean * den_var)
        return 1 - ssim_map.mean()

# Combined L1 and SSIM loss - common in MRI reconstruction
class CombinedLoss(nn.Module):
    """Weighted sum of an L1 term and an SSIM term: ``alpha * L1 + (1 - alpha) * SSIM``."""

    def __init__(self, alpha=0.84):
        super().__init__()
        self.alpha = alpha  # weight of the L1 term; the SSIM term gets 1 - alpha
        self.l1_loss = nn.L1Loss()
        self.ssim_loss = SSIMLoss()

    def forward(self, pred, target):
        """Return the blended loss for ``pred`` against ``target``."""
        weighted_l1 = self.alpha * self.l1_loss(pred, target)
        weighted_ssim = (1 - self.alpha) * self.ssim_loss(pred, target)
        return weighted_l1 + weighted_ssim

# Calculate PSNR metric
def calculate_psnr(img1, img2, data_range=1.0):
    """Calculate PSNR (in dB) between two tensors.

    Args:
        img1, img2: Tensors of the same (or broadcastable) shape. The mean
            squared error is taken over all elements, so a batched input yields
            one PSNR for the whole batch.
        data_range: Peak signal value. Defaults to 1.0, the value that was
            previously hard-coded.

    Returns:
        0-dim tensor ``20 * log10(data_range / sqrt(MSE))``. Identical inputs
        give ``inf`` because the MSE is zero.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 20 * torch.log10(data_range / torch.sqrt(mse))

# Calculate SSIM metric for numpy images
def calculate_ssim(img1, img2, data_range=None):
    """Calculate SSIM between two numpy arrays.

    Args:
        img1, img2: Arrays handed to ``structural_similarity`` unchanged.
        data_range: Dynamic range forwarded to ``structural_similarity``. When
            omitted it falls back to ``img1.max()``, which is the previous
            behaviour. Pass an explicit value (e.g. the known range of the
            reference image) to make scores independent of the first
            argument's maximum.

    Returns:
        The value returned by ``structural_similarity``.
    """
    if data_range is None:
        data_range = img1.max()
    return ssim(img1, img2, data_range=data_range)


# Training and Validation Functions

In [30]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Returns the sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    loss_total = 0.0

    with tqdm(dataloader, desc="Training") as progress:
        for batch_inputs, batch_targets in progress:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # [B, H, W] batches get a channel axis so they match the model output.
            if batch_inputs.dim() == 3:
                batch_inputs = batch_inputs.unsqueeze(1)
            if batch_targets.dim() == 3:
                batch_targets = batch_targets.unsqueeze(1)

            # Clear old gradients, then forward / backward / parameter update.
            optimizer.zero_grad()
            loss = criterion(model(batch_inputs), batch_targets)
            loss.backward()
            optimizer.step()

            # Accumulate and show the current batch loss on the progress bar.
            loss_total += loss.item()
            progress.set_postfix({'loss': loss.item()})

    return loss_total / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        Tuple ``(mean_loss, mean_psnr, sampled_ssim)``:
        * ``mean_loss`` / ``mean_psnr``: per-batch values summed and divided by
          ``len(dataloader)``.
        * ``sampled_ssim``: mean SSIM over at most the first 4 images of the
          FIRST batch only (SSIM runs on the CPU and is slow).
    """
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    # None means "not sampled yet". An explicit marker is used instead of
    # comparing the accumulated value with 0, so a sampled SSIM of exactly 0.0
    # cannot make a later batch be sampled in its place.
    sampled_ssim = None

    with torch.no_grad():
        with tqdm(dataloader, desc="Validation") as pbar:
            for inputs, targets in pbar:
                # Move tensors to the right device
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Add channel dimension if needed
                if len(inputs.shape) == 3:
                    inputs = inputs.unsqueeze(1)
                if len(targets.shape) == 3:
                    targets = targets.unsqueeze(1)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # PSNR is computed over the whole batch at once
                psnr = calculate_psnr(outputs, targets)

                # Update statistics
                running_loss += loss.item()
                running_psnr += psnr.item()

                # SSIM: first batch only, first 4 images only
                if sampled_ssim is None:
                    num_sampled = min(4, outputs.size(0))
                    ssim_total = 0.0
                    for i in range(num_sampled):
                        output_np = outputs[i, 0].cpu().numpy()
                        target_np = targets[i, 0].cpu().numpy()
                        ssim_total += calculate_ssim(output_np, target_np)
                    sampled_ssim = ssim_total / num_sampled

                pbar.set_postfix({'val_loss': loss.item(), 'psnr': psnr.item()})

    return running_loss / len(dataloader), running_psnr / len(dataloader), sampled_ssim


# Visualization Functions

In [31]:
def visualize_results(model, dataloader, device, epoch, output_dir):
    """Save a 4x3 grid (input | prediction | ground truth) for the first batch.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Only its first batch is used.
        device: Device the batch is moved to before the forward pass.
        epoch: Used in the figure title and the output file name.
        output_dir: Directory (created if missing) receiving ``epoch_<epoch>.png``.

    Returns:
        The matplotlib figure. It has already been saved and closed.
    """
    model.eval()

    # Get a batch of data
    inputs, targets = next(iter(dataloader))
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Make predictions
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(1) if len(inputs.shape) == 3 else inputs)

    def to_image(tensor):
        # imshow needs a 2-D array. A sample that still carries leading
        # dimensions (e.g. [C, H, W]) is reduced to its first entry along each.
        while tensor.dim() > 2:
            tensor = tensor[0]
        return tensor.cpu().numpy()

    num_rows = 4  # Show 4 examples
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, 20))
    fig.suptitle(f"FastMRI Reconstruction - Epoch {epoch}", fontsize=16)

    for i in range(num_rows):
        if i >= inputs.size(0):
            # Batch smaller than the grid: blank the unused row instead of
            # leaving empty axes with ticks.
            for ax in axes[i]:
                ax.axis('off')
            continue

        input_img = to_image(inputs[i])
        output_img = outputs[i, 0].cpu().numpy()
        target_img = to_image(targets[i])

        panels = (
            (input_img, "Input (Undersampled)"),
            (output_img, "Prediction"),
            (target_img, "Ground Truth"),
        )
        for ax, (img, title) in zip(axes[i], panels):
            ax.imshow(img, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        # Calculate metrics
        psnr = calculate_psnr(
            torch.from_numpy(output_img),
            torch.from_numpy(target_img)
        ).item()
        ssim_val = calculate_ssim(output_img, target_img)

        # Add metrics as text on the prediction panel
        axes[i, 1].text(
            10, 20,
            f'PSNR: {psnr:.2f} dB\nSSIM: {ssim_val:.4f}',
            color='white', fontsize=12,
            bbox=dict(facecolor='black', alpha=0.5)
        )

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)

    # Save and close this specific figure rather than whatever pyplot
    # currently considers the "current" one.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(f"{output_dir}/epoch_{epoch}.png")
    plt.close(fig)

    return fig


In [32]:
from pathlib import Path # Ensure Path is imported in this cell or globally

# ... (assuming _orig_init is correctly defined from the non-patched ProcessedFastMRIDataset.__init__)

def _patched_init(self, *args, **kwargs):
    _orig_init(self, *args, **kwargs)

    good_examples = []
    if hasattr(self, 'batch_files') and self.batch_files: # Check if batch_files exists and is not empty
        for idx, slice_idx in self.examples:
            # Ensure idx is a valid index for self.batch_files
            if 0 <= idx < len(self.batch_files):
                batch_file_string_path = self.batch_files[idx]
                # Convert the string path to a Path object
                path_obj = Path(batch_file_string_path)
                if "metadata" not in path_obj.name:  # Now path_obj.name is correct
                    good_examples.append((idx, slice_idx))
            else:
                print(f"Warning in patch: Invalid index {idx} for batch_files of length {len(self.batch_files)}")
        self.examples = good_examples
    elif not hasattr(self, 'batch_files') or not self.batch_files:
        # This case might occur if _orig_init failed to populate self.batch_files
        # or if use_processed was False (though the error trace suggests use_processed=True)
        print("Warning in patch: self.batch_files not populated or empty, skipping example filtering.")
        # self.examples would remain as whatever _orig_init set it to, or cause error if not set.

# ProcessedFastMRIDataset.__init__ = _patched_init # This line applies the patch
# print("Patched ProcessedFastMRIDataset to ignore files that contain 'metadata'")


# Main Training Script

In [33]:
def _checkpoint_state(epoch, model, optimizer, train_loss, val_loss, psnr, ssim_value):
    """Build the dict layout shared by the best/periodic/final checkpoints."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'psnr': psnr,
        'ssim': ssim_value,
    }


def _plot_training_curves(train_losses, val_losses, val_psnrs, val_ssims, save_path):
    """Save a three-panel figure (loss, PSNR, SSIM per epoch) to ``save_path``."""
    fig, (ax_loss, ax_psnr, ax_ssim) = plt.subplots(1, 3, figsize=(15, 5))

    ax_loss.plot(train_losses, label='Train Loss')
    ax_loss.plot(val_losses, label='Validation Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()
    ax_loss.set_title('Loss Curves')

    ax_psnr.plot(val_psnrs)
    ax_psnr.set_xlabel('Epoch')
    ax_psnr.set_ylabel('PSNR (dB)')
    ax_psnr.set_title('PSNR')

    ax_ssim.plot(val_ssims)
    ax_ssim.set_xlabel('Epoch')
    ax_ssim.set_ylabel('SSIM')
    ax_ssim.set_title('SSIM')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)


def train_unet_model(
    batch_size=16,
    epochs=50,
    lr=1e-4,
    weight_decay=1e-4,
    data_dir="/workspace/fastmri-reconstruction/processed_fastmri_data",
    output_dir="./unet_output",
    num_workers=4,
):
    """Train the U-Net on the processed FastMRI data and write checkpoints and plots.

    Every argument defaults to the value that used to be hard-coded, so calling
    ``train_unet_model()`` with no arguments runs the same configuration as before.

    Args:
        batch_size (int): Batch size for both loaders.
        epochs (int): Number of training epochs (must be >= 1).
        lr (float): Adam learning rate.
        weight_decay (float): Adam weight decay.
        data_dir: Directory containing the ``train`` and ``val`` sub-directories.
        output_dir: Directory receiving checkpoints, epoch figures and the
            training-curve plot; created if missing.
        num_workers (int): Worker processes per DataLoader.

    Returns:
        Tuple ``(model, train_losses, val_losses, val_psnrs, val_ssims)``; the
        lists hold one value per epoch.

    Raises:
        ValueError: ``epochs`` is smaller than 1.

    Uses the module-level ``device`` chosen in the setup cell.
    """
    if epochs < 1:
        # The final checkpoint reads the last entry of each history list.
        raise ValueError("epochs must be at least 1.")

    # Paths
    data_dir = Path(data_dir)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Load datasets
    train_dataset = ProcessedFastMRIDataset(
        data_dir=data_dir,
        mode="train",
        use_processed=True,
    )
    val_dataset = ProcessedFastMRIDataset(
        data_dir=data_dir,
        mode="val",
        use_processed=True,
    )

    print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")

    # Create dataloaders (only the training loader shuffles)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # Create model, loss, optimizer
    model = UNet(
        in_chans=1,
        out_chans=1,
        chans=32,  # Start with 32 channels as in fastMRI baseline
        num_pool_layers=4,
        drop_prob=0.0
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

    # Loss function - combination of L1 and SSIM; .to(device) also moves the SSIM window buffer
    criterion = CombinedLoss(alpha=0.84).to(device)

    # Optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    # Halve the learning rate when the validation loss stops improving for 5 epochs
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
    )

    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    val_psnrs = []
    val_ssims = []

    print("Starting training...")
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_losses.append(train_loss)

        # Validate
        val_loss, val_psnr, val_ssim = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_psnrs.append(val_psnr)
        val_ssims.append(val_ssim)

        # Update learning rate
        scheduler.step(val_loss)

        # Print stats
        print(f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | PSNR: {val_psnr:.2f} dB | SSIM: {val_ssim:.4f}")

        # Visualize results every 5 epochs and after the last one
        if epoch % 5 == 0 or epoch == epochs:
            visualize_results(model, val_loader, device, epoch, output_dir)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(
                _checkpoint_state(epoch, model, optimizer, train_loss, val_loss, val_psnr, val_ssim),
                f"{output_dir}/best_model.pth",
            )
            print(f"Saved best model at epoch {epoch}")

        # Save checkpoint every 10 epochs
        if epoch % 10 == 0:
            torch.save(
                _checkpoint_state(epoch, model, optimizer, train_loss, val_loss, val_psnr, val_ssim),
                f"{output_dir}/checkpoint_epoch_{epoch}.pth",
            )

    # Save final model
    torch.save(
        _checkpoint_state(
            epochs, model, optimizer,
            train_losses[-1], val_losses[-1], val_psnrs[-1], val_ssims[-1],
        ),
        f"{output_dir}/final_model.pth",
    )

    # Plot training curves
    _plot_training_curves(
        train_losses, val_losses, val_psnrs, val_ssims,
        f"{output_dir}/training_curves.png",
    )

    return model, train_losses, val_losses, val_psnrs, val_ssims

# Execute the training
if __name__ == "__main__":
    model, train_losses, val_losses, val_psnrs, val_ssims = train_unet_model()


Filtering out metadata file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_metadata.pt
Found 61 .pt files to process in /workspace/fastmri-reconstruction/processed_fastmri_data/train after filtering.
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0000.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0001.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0002.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0003.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0004.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0005.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_bat

Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.207893 | Val Loss: 0.105392 | PSNR: 23.30 dB | SSIM: 0.5866
Saved best model at epoch 1

Epoch 2/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.086673 | Val Loss: 0.079195 | PSNR: 25.01 dB | SSIM: 0.6768
Saved best model at epoch 2

Epoch 3/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.073843 | Val Loss: 0.075885 | PSNR: 25.49 dB | SSIM: 0.6557
Saved best model at epoch 3

Epoch 4/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.069361 | Val Loss: 0.070397 | PSNR: 25.91 dB | SSIM: 0.6452
Saved best model at epoch 4

Epoch 5/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.066259 | Val Loss: 0.068550 | PSNR: 26.09 dB | SSIM: 0.6361
Saved best model at epoch 5

Epoch 6/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.064212 | Val Loss: 0.064543 | PSNR: 26.54 dB | SSIM: 0.6926
Saved best model at epoch 6

Epoch 7/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.061781 | Val Loss: 0.062506 | PSNR: 26.84 dB | SSIM: 0.7075
Saved best model at epoch 7

Epoch 8/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.059999 | Val Loss: 0.065396 | PSNR: 26.56 dB | SSIM: 0.6434

Epoch 9/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.060787 | Val Loss: 0.065373 | PSNR: 26.67 dB | SSIM: 0.6327

Epoch 10/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.059233 | Val Loss: 0.060974 | PSNR: 27.13 dB | SSIM: 0.6393
Saved best model at epoch 10

Epoch 11/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.057333 | Val Loss: 0.058699 | PSNR: 27.30 dB | SSIM: 0.6996
Saved best model at epoch 11

Epoch 12/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.057335 | Val Loss: 0.059036 | PSNR: 27.27 dB | SSIM: 0.6752

Epoch 13/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.055872 | Val Loss: 0.059657 | PSNR: 27.06 dB | SSIM: 0.6755

Epoch 14/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.056339 | Val Loss: 0.059152 | PSNR: 27.16 dB | SSIM: 0.6885

Epoch 15/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.055558 | Val Loss: 0.056544 | PSNR: 27.55 dB | SSIM: 0.7006
Saved best model at epoch 15

Epoch 16/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.054752 | Val Loss: 0.057918 | PSNR: 27.39 dB | SSIM: 0.6963

Epoch 17/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.054699 | Val Loss: 0.057303 | PSNR: 27.63 dB | SSIM: 0.6744

Epoch 18/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.054567 | Val Loss: 0.056362 | PSNR: 27.39 dB | SSIM: 0.6379
Saved best model at epoch 18

Epoch 19/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.052763 | Val Loss: 0.056682 | PSNR: 27.23 dB | SSIM: 0.6782

Epoch 20/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.053229 | Val Loss: 0.055126 | PSNR: 27.87 dB | SSIM: 0.6745
Saved best model at epoch 20

Epoch 21/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.052672 | Val Loss: 0.055567 | PSNR: 27.65 dB | SSIM: 0.7027

Epoch 22/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.052818 | Val Loss: 0.056066 | PSNR: 27.48 dB | SSIM: 0.6689

Epoch 23/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.052038 | Val Loss: 0.054781 | PSNR: 27.79 dB | SSIM: 0.6737
Saved best model at epoch 23

Epoch 24/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.051645 | Val Loss: 0.055318 | PSNR: 27.62 dB | SSIM: 0.6649

Epoch 25/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.051221 | Val Loss: 0.054384 | PSNR: 27.72 dB | SSIM: 0.6716
Saved best model at epoch 25

Epoch 26/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.051205 | Val Loss: 0.056845 | PSNR: 27.35 dB | SSIM: 0.6936

Epoch 27/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.050661 | Val Loss: 0.054085 | PSNR: 27.88 dB | SSIM: 0.6892
Saved best model at epoch 27

Epoch 28/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.049455 | Val Loss: 0.054385 | PSNR: 27.78 dB | SSIM: 0.6946

Epoch 29/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.049608 | Val Loss: 0.054141 | PSNR: 27.99 dB | SSIM: 0.6649

Epoch 30/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.049638 | Val Loss: 0.055173 | PSNR: 27.71 dB | SSIM: 0.6955

Epoch 31/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.049581 | Val Loss: 0.055338 | PSNR: 27.65 dB | SSIM: 0.6633

Epoch 32/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.049286 | Val Loss: 0.054257 | PSNR: 27.81 dB | SSIM: 0.6961

Epoch 33/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.048342 | Val Loss: 0.055467 | PSNR: 27.58 dB | SSIM: 0.6951

Epoch 34/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.046558 | Val Loss: 0.053275 | PSNR: 27.98 dB | SSIM: 0.6716
Saved best model at epoch 34

Epoch 35/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.045241 | Val Loss: 0.053534 | PSNR: 27.90 dB | SSIM: 0.6923

Epoch 36/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.044426 | Val Loss: 0.054048 | PSNR: 27.79 dB | SSIM: 0.6696

Epoch 37/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.044505 | Val Loss: 0.054789 | PSNR: 27.78 dB | SSIM: 0.6694

Epoch 38/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.043851 | Val Loss: 0.054411 | PSNR: 27.71 dB | SSIM: 0.6943

Epoch 39/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.043435 | Val Loss: 0.053669 | PSNR: 27.87 dB | SSIM: 0.6738

Epoch 40/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.043061 | Val Loss: 0.053414 | PSNR: 27.95 dB | SSIM: 0.6886

Epoch 41/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.041339 | Val Loss: 0.052908 | PSNR: 27.94 dB | SSIM: 0.6832
Saved best model at epoch 41

Epoch 42/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.040585 | Val Loss: 0.052954 | PSNR: 28.02 dB | SSIM: 0.6848

Epoch 43/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.040302 | Val Loss: 0.053621 | PSNR: 27.85 dB | SSIM: 0.6764

Epoch 44/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.039779 | Val Loss: 0.053630 | PSNR: 27.80 dB | SSIM: 0.6874

Epoch 45/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.039591 | Val Loss: 0.053100 | PSNR: 27.90 dB | SSIM: 0.6908

Epoch 46/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.039312 | Val Loss: 0.052772 | PSNR: 27.96 dB | SSIM: 0.6700
Saved best model at epoch 46

Epoch 47/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.039140 | Val Loss: 0.053325 | PSNR: 27.90 dB | SSIM: 0.6817

Epoch 48/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.039113 | Val Loss: 0.053539 | PSNR: 27.86 dB | SSIM: 0.6723

Epoch 49/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.038684 | Val Loss: 0.053560 | PSNR: 27.79 dB | SSIM: 0.6834

Epoch 50/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.038647 | Val Loss: 0.052595 | PSNR: 28.03 dB | SSIM: 0.6935
Saved best model at epoch 50


# Inference and Model Evaluation

In [34]:
def _save_sample_grid(inputs, outputs, targets, psnrs, ssims, output_dir, max_rows=4):
    """Save an input / prediction / ground-truth comparison figure for one batch.

    Writes ``evaluation_samples.png`` into ``output_dir``.

    Args:
        inputs, outputs, targets: Batched tensors indexed as ``[sample, 0]`` to
            obtain the image to display (assumes (B, 1, H, W) -- TODO confirm).
        psnrs, ssims: Per-sample metrics where index ``j`` corresponds to
            sample ``j`` of this batch.
        output_dir: Directory the PNG is written to (must already exist).
        max_rows: Number of rows in the grid; at most this many samples are shown.
    """
    n_show = min(max_rows, outputs.size(0))

    # squeeze=False keeps `axes` 2-D even if max_rows is ever set to 1.
    fig, axes = plt.subplots(max_rows, 3, figsize=(15, 5 * max_rows), squeeze=False)
    fig.suptitle("FastMRI Reconstruction Results", fontsize=16)

    for j in range(n_show):
        columns = (
            (inputs[j, 0], "Input (Undersampled)"),
            (outputs[j, 0], "Prediction"),
            (targets[j, 0], "Ground Truth"),
        )
        for ax, (image, title) in zip(axes[j], columns):
            ax.imshow(image.cpu().numpy(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        # Overlay this sample's metrics on the prediction panel.
        axes[j, 1].text(
            10, 20,
            f'PSNR: {psnrs[j]:.2f} dB\nSSIM: {ssims[j]:.4f}',
            color='white', fontsize=12,
            bbox=dict(facecolor='black', alpha=0.5)
        )

    # When the batch holds fewer than `max_rows` samples, blank the unused rows
    # instead of leaving empty framed axes in the saved figure.
    for ax in axes[n_show:].ravel():
        ax.axis('off')

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    fig.savefig(os.path.join(output_dir, "evaluation_samples.png"))
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)


def _save_metric_histograms(psnrs, ssims, avg_psnr, avg_ssim, output_dir):
    """Save side-by-side PSNR and SSIM histograms as ``metric_histograms.png``."""
    fig, (ax_psnr, ax_ssim) = plt.subplots(1, 2, figsize=(12, 5))

    ax_psnr.hist(psnrs, bins=20)
    ax_psnr.set_xlabel('PSNR (dB)')
    ax_psnr.set_ylabel('Count')
    ax_psnr.set_title(f'PSNR Histogram (Avg: {avg_psnr:.2f} dB)')

    ax_ssim.hist(ssims, bins=20)
    ax_ssim.set_xlabel('SSIM')
    ax_ssim.set_ylabel('Count')
    ax_ssim.set_title(f'SSIM Histogram (Avg: {avg_ssim:.4f})')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "metric_histograms.png"))
    plt.close(fig)


def evaluate_model(model_path, dataloader, device, output_dir="./evaluation"):
    """Evaluate a trained U-Net checkpoint on a dataset and save a report.

    Loads the checkpoint into a ``UNet(in_chans=1, out_chans=1, chans=32,
    num_pool_layers=4)``, runs it over ``dataloader`` without gradients, and
    computes PSNR and SSIM for every sample.

    Files written into ``output_dir`` (created if missing):
        * ``evaluation_samples.png``  -- up to 4 samples from the first batch
        * ``evaluation_results.txt`` -- average PSNR / SSIM
        * ``metric_histograms.png``  -- per-sample metric distributions

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` that holds
            the weights under the ``'model_state_dict'`` key.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
            3-D batches get a channel dimension inserted at position 1.
        device: Device the model and batches are moved to.
        output_dir: Destination directory for the report files.

    Returns:
        dict with keys ``'psnr'`` and ``'ssim'`` (lists of per-sample values)
        and ``'avg_psnr'`` / ``'avg_ssim'`` (their means; NaN when the
        dataloader produced no samples).
    """
    # Rebuild the architecture and restore the trained weights.
    model = UNet(in_chans=1, out_chans=1, chans=32, num_pool_layers=4).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    psnrs = []
    ssims = []

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(tqdm(dataloader, desc="Evaluating")):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Add channel dimension if needed
            if inputs.dim() == 3:
                inputs = inputs.unsqueeze(1)
            if targets.dim() == 3:
                targets = targets.unsqueeze(1)

            outputs = model(inputs)

            # Per-sample metrics: PSNR on single-sample tensor slices,
            # SSIM on the channel-0 numpy images.
            for j in range(outputs.size(0)):
                output_np = outputs[j, 0].cpu().numpy()
                target_np = targets[j, 0].cpu().numpy()

                psnrs.append(calculate_psnr(outputs[j:j+1], targets[j:j+1]).item())
                ssims.append(calculate_ssim(output_np, target_np))

            # Visualize first batch only; at this point psnrs/ssims contain
            # exactly this batch's values, so index j lines up with sample j.
            if i == 0:
                _save_sample_grid(inputs, outputs, targets, psnrs, ssims, output_dir)

    # Calculate average metrics. An empty dataloader would otherwise send
    # np.mean([]) down its RuntimeWarning path, so report NaN explicitly.
    if psnrs:
        avg_psnr = np.mean(psnrs)
        avg_ssim = np.mean(ssims)
    else:
        avg_psnr = avg_ssim = float("nan")

    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")

    results = {
        'psnr': psnrs,
        'ssim': ssims,
        'avg_psnr': avg_psnr,
        'avg_ssim': avg_ssim
    }

    # Save the summary in a text file
    with open(os.path.join(output_dir, "evaluation_results.txt"), 'w') as f:
        f.write("U-Net Baseline Evaluation Results\n")
        f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")

    _save_metric_histograms(psnrs, ssims, avg_psnr, avg_ssim, output_dir)

    return results

# Example usage (uncomment to run)
# val_dataset = ProcessedFastMRIDataset(data_dir="./processed_fastmri_data", mode='val', use_processed=True)
# val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
# results = evaluate_model("./unet_output/best_model.pth", val_loader, device)


# Run

In [None]:
if __name__ == "__main__":
    # --- Stage 1: train -------------------------------------------------
    # train_unet_model() is invoked with its defaults; all five return
    # values are kept as notebook-level names.
    print("Starting UNet baseline training for FastMRI...")
    model, train_losses, val_losses, val_psnrs, val_ssims = train_unet_model()

    # --- Stage 2: evaluate the saved checkpoint on the 'val' split -------
    print("\nEvaluating the best model...")
    val_dataset = ProcessedFastMRIDataset(data_dir="./processed_fastmri_data", mode='val', use_processed=True)
    # shuffle=False keeps the evaluation order fixed from run to run.
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    results = evaluate_model(
        model_path="./unet_output/best_model.pth",
        dataloader=val_loader,
        device=device,
    )

    print("\nBaseline U-Net training and evaluation complete!")


Starting UNet baseline training for FastMRI...
Filtering out metadata file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_metadata.pt
Found 61 .pt files to process in /workspace/fastmri-reconstruction/processed_fastmri_data/train after filtering.
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0000.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0001.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0002.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0003.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0004.pt
Processing file: /workspace/fastmri-reconstruction/processed_fastmri_data/train/fastmri_train_4x_batch_0005.pt
Processing file: /workspace/fastmri-reconstruction/pr

Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.224974 | Val Loss: 0.102148 | PSNR: 22.94 dB | SSIM: 0.6178
Saved best model at epoch 1

Epoch 2/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.082079 | Val Loss: 0.074428 | PSNR: 25.56 dB | SSIM: 0.6475
Saved best model at epoch 2

Epoch 3/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.070839 | Val Loss: 0.068385 | PSNR: 26.24 dB | SSIM: 0.6604
Saved best model at epoch 3

Epoch 4/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.066534 | Val Loss: 0.065712 | PSNR: 26.53 dB | SSIM: 0.6944
Saved best model at epoch 4

Epoch 5/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.064120 | Val Loss: 0.064310 | PSNR: 26.68 dB | SSIM: 0.6226
Saved best model at epoch 5

Epoch 6/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.061466 | Val Loss: 0.061933 | PSNR: 26.96 dB | SSIM: 0.6416
Saved best model at epoch 6

Epoch 7/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.060840 | Val Loss: 0.062556 | PSNR: 26.79 dB | SSIM: 0.6104

Epoch 8/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.058583 | Val Loss: 0.059169 | PSNR: 27.36 dB | SSIM: 0.6205
Saved best model at epoch 8

Epoch 9/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.057862 | Val Loss: 0.059192 | PSNR: 27.28 dB | SSIM: 0.6574

Epoch 10/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.057348 | Val Loss: 0.058831 | PSNR: 27.55 dB | SSIM: 0.6384
Saved best model at epoch 10

Epoch 11/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.055812 | Val Loss: 0.056779 | PSNR: 27.68 dB | SSIM: 0.6446
Saved best model at epoch 11

Epoch 12/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.055217 | Val Loss: 0.057698 | PSNR: 27.60 dB | SSIM: 0.6658

Epoch 13/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.054788 | Val Loss: 0.056597 | PSNR: 27.81 dB | SSIM: 0.6620
Saved best model at epoch 13

Epoch 14/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.054277 | Val Loss: 0.055820 | PSNR: 27.79 dB | SSIM: 0.6641
Saved best model at epoch 14

Epoch 15/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.053984 | Val Loss: 0.055445 | PSNR: 27.84 dB | SSIM: 0.6443
Saved best model at epoch 15

Epoch 16/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.053282 | Val Loss: 0.055264 | PSNR: 27.88 dB | SSIM: 0.6592
Saved best model at epoch 16

Epoch 17/50


Training:   0%|          | 0/61 [00:00<?, ?it/s]