In [1]:
%matplotlib widget

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os
import wandb
import tifffile
import psutil
from glob import glob
import time
import datetime
import imageio
from sys import stdout
from typing import List, Tuple, Optional
from tqdm.auto import tqdm 

print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.6.0+cu124


In [3]:
class Args:
    """Default hyper-parameters, paths and runtime options for DualSR.

    Create one instance (``args = Args()``) and override attributes as needed;
    every value assigned below is only a default.
    """

    def __init__(self):
        # ---- Model architecture -------------------------------------------
        self.scale = 4                    # Upscaling factor
        self.numResBlocks = 16            # Number of residual blocks
        self.ngsrf = 64                   # Number of filters for generator SR

        # ---- Optimisation / schedule --------------------------------------
        self.epoch = 500                  # Number of epochs
        self.epoch_step = 50              # T_0 for CosineAnnealingWarmRestarts
        self.lr = 0.0001                  # Learning rate
        self.eta_min = 1e-6               # Minimum learning rate for scheduler
        self.batch_size = 16              # Batch size
        self.patch_size = 192             # Training crop size
        self.itersPerEpoch = 200          # Iterations per epoch
        self.iterCyclesPerEpoch = 1       # Iteration cycles per epoch
        self.augFlag = True               # Data augmentation flag

        # ---- Checkpointing / validation -----------------------------------
        self.save_freq = 10               # Model saving frequency
        self.print_freq = 1               # Validation frequency
        self.valNum = 5                   # Number of validation samples
        self.valTest = False              # Validation test flag
        self.checkpoint_dir = './checkpoints'  # Checkpoints directory
        self.continue_train = False       # Continue training from checkpoint
        self.continueEpoch = 0            # Epoch to continue from
        self.use_best_model = True        # Use best model for testing

        # ---- Data and output locations ------------------------------------
        self.modelName = 'DualSR_mixed90'                               # Model name
        self.dataset_dir = './Dataset/Bentheimer_mixed_fw90/Train'      # Dataset directory
        self.test_dir = './Dataset/Bentheimer_mixed_fw90/Test'          # Test directory
        self.test_save_dir = './test_results'                           # Test results directory
        self.test_temp_save_dir = './test_temp'                         # Temp test results directory

        # ---- Runtime -------------------------------------------------------
        self.phase = 'train'              # Phase: train or test
        self.gpuIDs = "0"                 # GPU IDs to use
        self.chunk_size = 1000            # Size of chunks for processing large images
        self.overlap = 50                 # Overlap between chunks

        # ---- Weights & Biases ---------------------------------------------
        self.wandb_api_key = None                          # API key for wandb authentication
        self.wandb_project = "DualSR_Bentheimer_mixed90"   # Project name for wandb
        self.wandb_entity = None                           # Optional: wandb username or team name

In [4]:
# Network components
class InstanceNorm(nn.Module):
    """Instance Normalization Layer with a learnable per-channel affine transform.

    Each (sample, channel) pair is normalised over all of its spatial dimensions
    (H, W for 4-D input; D, H, W for 5-D input) using the biased variance, then
    multiplied by ``scale`` and shifted by ``offset``.

    Args:
        channels: Number of channels C (size of dim 1 of the input).
        eps: Constant added to the variance for numerical stability.
    """
    def __init__(self, channels, eps=1e-5):
        super(InstanceNorm, self).__init__()
        self.scale = nn.Parameter(torch.ones(channels))
        self.offset = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        # Reduce over every dim after (batch, channel): (2, 3) for 4-D, (2, 3, 4) for 5-D.
        spatial_dims = tuple(range(2, x.dim()))
        mean = x.mean(dim=spatial_dims, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=spatial_dims, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)

        # Shape the affine params as [1, C, 1, ..., 1] matching x's rank. A fixed
        # 4-D view would align C with the depth axis of a 5-D input.
        param_shape = (1, -1) + (1,) * (x.dim() - 2)
        return x * self.scale.view(param_shape) + self.offset.view(param_shape)

class ResBlock(nn.Module):
    """Residual Block for EDSR: conv -> ReLU -> [norm] -> conv, added to the input.

    Args:
        channels: Input and output channel count (equal, so the skip-add works).
        kernel_size: Convolution kernel size. Padding ``kernel_size // 2`` keeps the
            spatial size unchanged for odd kernel sizes.
        norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Only used when
            ``apply_norm`` is True.
        apply_norm: Insert the normalisation layer between the ReLU and the second conv.
        ndims: 2 -> Conv2d / BatchNorm2d; any other value -> Conv3d / BatchNorm3d.

    Raises:
        ValueError: If ``apply_norm`` is True and ``norm_type`` is not recognised.
            Previously this was accepted and failed later in ``forward`` with an
            AttributeError because ``self.norm`` was never created.
    """
    def __init__(self, channels, kernel_size=3, norm_type='instancenorm', apply_norm=False, ndims=2):
        super(ResBlock, self).__init__()
        padding = kernel_size // 2
        conv = nn.Conv2d if ndims == 2 else nn.Conv3d

        self.conv1 = conv(channels, channels, kernel_size, padding=padding)
        self.conv2 = conv(channels, channels, kernel_size, padding=padding)
        self.relu = nn.ReLU(inplace=True)

        self.use_norm = bool(apply_norm)
        if self.use_norm:
            kind = norm_type.lower()
            if kind == 'batchnorm':
                self.norm = nn.BatchNorm2d(channels) if ndims == 2 else nn.BatchNorm3d(channels)
            elif kind == 'instancenorm':
                self.norm = InstanceNorm(channels)
            else:
                raise ValueError(
                    f"Unsupported norm_type {norm_type!r}; expected 'batchnorm' or 'instancenorm'"
                )

    def forward(self, x):
        out = self.relu(self.conv1(x))
        if self.use_norm:
            out = self.norm(out)
        out = self.conv2(out)
        out += x  # identity skip connection
        return out

class Upsample(nn.Module):
    """Upsampling module for EDSR: conv -> ReLU -> [norm] -> nearest-exact interpolation.

    Every spatial dimension of the input is enlarged by ``scale``.

    Args:
        scale: Integer factor applied to every spatial dimension.
        num_filters: Input and output channels of the 3x3 convolution.
        norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Only used when
            ``apply_norm`` is True.
        apply_norm: Insert normalisation after the activation.
        ndims: 2 -> Conv2d (input [B, C, H, W]); otherwise Conv3d (input [B, C, D, H, W]).

    Raises:
        ValueError: If ``apply_norm`` is True and ``norm_type`` is not recognised.
    """
    def __init__(self, scale, num_filters, norm_type='instancenorm', apply_norm=False, ndims=2):
        super(Upsample, self).__init__()
        self.scale = scale
        self.ndims = ndims
        self.num_filters = num_filters

        conv = nn.Conv2d if ndims == 2 else nn.Conv3d
        self.conv = conv(num_filters, num_filters, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.use_norm = bool(apply_norm)
        if self.use_norm:
            kind = norm_type.lower()
            if kind == 'batchnorm':
                self.norm = nn.BatchNorm2d(num_filters) if ndims == 2 else nn.BatchNorm3d(num_filters)
            elif kind == 'instancenorm':
                self.norm = InstanceNorm(num_filters)
            else:
                raise ValueError(
                    f"Unsupported norm_type {norm_type!r}; expected 'batchnorm' or 'instancenorm'"
                )

    def forward(self, x):
        x = self.relu(self.conv(x))
        if self.use_norm:
            x = self.norm(x)
        # F.interpolate infers 2-D vs 3-D from x's rank, so one call covers both.
        return F.interpolate(x, scale_factor=self.scale, mode='nearest-exact')

class Upsample1D(nn.Module):
    """1D Upsampling module for EDSR: conv -> ReLU -> [norm] -> interpolation along one axis.

    Only the first spatial dimension (H for 2-D input, D for 3-D input) is enlarged
    by ``scale``. The remaining spatial dimensions are unchanged.

    Args:
        scale: Integer factor for the first spatial dimension.
        num_filters: Input and output channels of the 3x3 convolution.
        norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Only used when
            ``apply_norm`` is True.
        apply_norm: Insert normalisation after the activation.
        ndims: 2 -> Conv2d (input [B, C, H, W]); otherwise Conv3d (input [B, C, D, H, W]).

    Raises:
        ValueError: If ``apply_norm`` is True and ``norm_type`` is not recognised.
    """
    def __init__(self, scale, num_filters, norm_type='instancenorm', apply_norm=False, ndims=2):
        super(Upsample1D, self).__init__()
        self.scale = scale
        self.ndims = ndims
        self.num_filters = num_filters

        conv = nn.Conv2d if ndims == 2 else nn.Conv3d
        self.conv = conv(num_filters, num_filters, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.use_norm = bool(apply_norm)
        if self.use_norm:
            kind = norm_type.lower()
            if kind == 'batchnorm':
                self.norm = nn.BatchNorm2d(num_filters) if ndims == 2 else nn.BatchNorm3d(num_filters)
            elif kind == 'instancenorm':
                self.norm = InstanceNorm(num_filters)
            else:
                raise ValueError(
                    f"Unsupported norm_type {norm_type!r}; expected 'batchnorm' or 'instancenorm'"
                )

    def forward(self, x):
        x = self.relu(self.conv(x))
        if self.use_norm:
            x = self.norm(x)
        # Scale only the first spatial axis; keep the others at factor 1.
        scale_factor = (self.scale, 1) if self.ndims == 2 else (self.scale, 1, 1)
        return F.interpolate(x, scale_factor=scale_factor, mode='nearest-exact')

class EDSR(nn.Module):
    """EDSR Super-Resolution Network for single-channel input.

    Pipeline: 3x3 conv head -> ``num_res_blocks`` ResBlocks -> 3x3 conv -> global
    skip from the head -> upsampling stages -> 3x3 conv to 1 channel -> tanh.
    Every spatial dimension is enlarged by ``scale``, and the output lies in (-1, 1).

    Args:
        scale: Total upscaling factor. Any integer >= 2 of the form 2**a * 3**b is
            supported. It is realised as x2 stages followed by x3 stages
            (2 -> [2], 3 -> [3], 4 -> [2, 2], 8 -> [2, 2, 2], 6 -> [2, 3]).
        num_filters: Feature channels in the trunk.
        num_res_blocks: Number of residual blocks.
        ndims: 2 -> 2-D convolutions ([B, 1, H, W]); otherwise 3-D ([B, 1, D, H, W]).

    Raises:
        ValueError: If ``scale`` is not of that form. Previously an unsupported scale
            built a model without ``self.upsample`` that crashed in ``forward``.
    """

    @staticmethod
    def _upsample_factors(scale):
        """Decompose ``scale`` into per-stage factors: all 2s first, then all 3s."""
        message = f"scale must be an integer >= 2 made only of factors 2 and 3, got {scale!r}"
        if scale != int(scale) or scale < 2:
            raise ValueError(message)
        remaining = int(scale)
        factors = []
        for factor in (2, 3):
            while remaining % factor == 0:
                factors.append(factor)
                remaining //= factor
        if remaining != 1:
            raise ValueError(message)
        return factors

    def __init__(self, scale, num_filters=64, num_res_blocks=8, ndims=2):
        super(EDSR, self).__init__()
        self.ndims = ndims
        # Validate before building any layers so a bad scale fails immediately.
        factors = EDSR._upsample_factors(scale)
        conv = nn.Conv2d if ndims == 2 else nn.Conv3d

        # Head, residual trunk and trunk-closing conv.
        self.first_conv = conv(1, num_filters, 3, padding=1)
        self.res_blocks = nn.Sequential(*[
            ResBlock(num_filters, 3, norm_type='instancenorm', apply_norm=False, ndims=ndims)
            for _ in range(num_res_blocks)
        ])
        self.final_feature_conv = conv(num_filters, num_filters, 3, padding=1)

        # Upsampling stages. Registration order and indices match the previous
        # explicit lists for scales 2/3/4/8, so existing checkpoints still load.
        self.upsample = nn.ModuleList([
            Upsample(factor, num_filters, norm_type='instancenorm', apply_norm=False, ndims=ndims)
            for factor in factors
        ])

        # Reconstruction conv and output activation.
        self.final_conv = conv(num_filters, 1, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.first_conv(x)
        residual = x

        x = self.res_blocks(x)
        x = self.final_feature_conv(x)
        x = x + residual  # global skip connection

        for up in self.upsample:
            x = up(x)

        x = self.final_conv(x)
        return self.tanh(x)

class EDSR1D(nn.Module):
    """EDSR Super-Resolution Network with 1D upsampling, for single-channel input.

    Same pipeline as EDSR, but the upsampling stages use Upsample1D, so only the
    first spatial dimension (H for 2-D, D for 3-D input) is enlarged by ``scale``.
    The output passes through tanh and lies in (-1, 1).

    Args:
        scale: Total upscaling factor. Any integer >= 2 of the form 2**a * 3**b is
            supported. It is realised as x2 stages followed by x3 stages
            (2 -> [2], 3 -> [3], 4 -> [2, 2], 8 -> [2, 2, 2], 6 -> [2, 3]).
        num_filters: Feature channels in the trunk.
        num_res_blocks: Number of residual blocks.
        ndims: 2 -> 2-D convolutions ([B, 1, H, W]); otherwise 3-D ([B, 1, D, H, W]).

    Raises:
        ValueError: If ``scale`` is not of that form. Previously an unsupported scale
            built a model without ``self.upsample`` that crashed in ``forward``.
    """

    @staticmethod
    def _upsample_factors(scale):
        """Decompose ``scale`` into per-stage factors: all 2s first, then all 3s."""
        message = f"scale must be an integer >= 2 made only of factors 2 and 3, got {scale!r}"
        if scale != int(scale) or scale < 2:
            raise ValueError(message)
        remaining = int(scale)
        factors = []
        for factor in (2, 3):
            while remaining % factor == 0:
                factors.append(factor)
                remaining //= factor
        if remaining != 1:
            raise ValueError(message)
        return factors

    def __init__(self, scale, num_filters=64, num_res_blocks=8, ndims=2):
        super(EDSR1D, self).__init__()
        self.ndims = ndims
        # Validate before building any layers so a bad scale fails immediately.
        factors = EDSR1D._upsample_factors(scale)
        conv = nn.Conv2d if ndims == 2 else nn.Conv3d

        # Head, residual trunk and trunk-closing conv.
        self.first_conv = conv(1, num_filters, 3, padding=1)
        self.res_blocks = nn.Sequential(*[
            ResBlock(num_filters, 3, norm_type='instancenorm', apply_norm=False, ndims=ndims)
            for _ in range(num_res_blocks)
        ])
        self.final_feature_conv = conv(num_filters, num_filters, 3, padding=1)

        # Single-axis upsampling stages. Registration order and indices match the
        # previous explicit lists for scales 2/3/4/8, so existing checkpoints still load.
        self.upsample = nn.ModuleList([
            Upsample1D(factor, num_filters, norm_type='instancenorm', apply_norm=False, ndims=ndims)
            for factor in factors
        ])

        # Reconstruction conv and output activation.
        self.final_conv = conv(num_filters, 1, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.first_conv(x)
        residual = x

        x = self.res_blocks(x)
        x = self.final_feature_conv(x)
        x = x + residual  # global skip connection

        for up in self.upsample:
            x = up(x)

        x = self.final_conv(x)
        return self.tanh(x)

class DualSRDataset(torch.utils.data.Dataset):
    """
    Custom dataset for Dual Super Resolution that handles the scaling factor mismatch
    between HR and LR datasets.

    LR item ``i`` is paired with the ``scale`` consecutive HR items
    ``i*scale .. i*scale + scale - 1`` along dim 0. Every HR slice is therefore used
    exactly once, and the dataset length equals the number of LR slices.
    """
    def __init__(self, hr_tensor, lr_tensor, scale):
        """
        Args:
            hr_tensor (torch.Tensor): High resolution stack. Dim 0 must equal
                ``scale * lr_tensor.shape[0]`` (intended layout [B*scale, C, H, W]).
            lr_tensor (torch.Tensor): Low resolution stack with B items along dim 0
                (intended layout [B, C, H/scale, W/scale]).
            scale (int): Upscaling factor.

        Raises:
            AssertionError: If the dim-0 sizes are not in a ``scale`` : 1 ratio.
        """
        self.hr_tensor = hr_tensor
        self.lr_tensor = lr_tensor
        self.scale = scale

        # Verify tensor shapes match our requirements
        self.hr_batch_size = hr_tensor.shape[0]
        self.lr_batch_size = lr_tensor.shape[0]

        assert self.hr_batch_size == self.lr_batch_size * scale, \
            f"HR batch size ({self.hr_batch_size}) must be scale ({scale}) times LR batch size ({self.lr_batch_size})"

    def __len__(self):
        return self.lr_batch_size

    def __getitem__(self, idx):
        """
        For each LR slice, return the LR slice and its ``scale`` matching HR slices.

        Args:
            idx (int): Item index; negative values count from the end.

        Returns:
            tuple: (lr_slice, hr_slices), where hr_slices has ``scale`` items along dim 0.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Normalise negative indices first. Otherwise the HR slice below would use a
        # negative start and an end of 0, returning an empty block.
        if idx < 0:
            idx += self.lr_batch_size
        if not 0 <= idx < self.lr_batch_size:
            raise IndexError(f"index out of range for DualSRDataset of length {self.lr_batch_size}")

        lr_slice = self.lr_tensor[idx]
        hr_start = idx * self.scale
        hr_slices = self.hr_tensor[hr_start:hr_start + self.scale]
        return lr_slice, hr_slices

# Straight-through estimator for rounding
class StraightThroughRound(torch.autograd.Function):
    """Round to the nearest integer in the forward pass; identity gradient in backward.

    Use via ``StraightThroughRound.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved for backward: the gradient does not depend on x, so
        # keeping it would only hold the input tensor in memory.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass gradient through unchanged (straight-through).
        return grad_output

In [5]:
def _gaussian_kernel(kernel_size, sigma, n_channels, dtype):
    """Create a Gaussian kernel for blurring"""
    x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype)
    g = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    g_norm2d = torch.sum(g) ** 2
    g_kernel = torch.outer(g, g) / g_norm2d
    g_kernel = g_kernel.reshape(1, 1, kernel_size, kernel_size)
    g_kernel = g_kernel.repeat(n_channels, 1, 1, 1)
    return g_kernel

def apply_blur(img, kernel_size, sigma, n_channel):
    """Gaussian-blur each channel of ``img`` independently (grouped convolution).

    Args:
        img: Image tensor accepted by ``F.conv2d`` with ``n_channel`` channels —
            presumably [B, n_channel, H, W].
        kernel_size: Side of the square kernel. Zero padding of ``kernel_size // 2``
            keeps H and W unchanged for odd sizes.
        sigma: Gaussian standard deviation in pixels.
        n_channel: Number of channels, also used as the group count.

    Returns:
        The blurred tensor, on ``img``'s device and in ``img``'s dtype.
    """
    kernel = _gaussian_kernel(kernel_size, sigma, n_channel, img.dtype)
    return F.conv2d(img, kernel.to(img.device), padding=kernel_size // 2, groups=n_channel)

def augment_data(image):
    """Apply random brightness and contrast jitter to an image batch in [-1, 1].

    Two factors are drawn uniformly from [0.8, 1.2], a contrast factor and then a
    brightness factor, shared by the whole batch. The image is multiplied by the
    brightness factor, each (sample, channel) is stretched about its spatial mean by
    the contrast factor, and the result is clamped to [-1, 1].

    Args:
        image: torch.Tensor, or a numpy array / array-like, laid out as
            [B, C, *spatial] (typically [B, C, H, W] or [B, C, D, H, W]).

    Returns:
        torch.Tensor with the same shape. It stays on the input tensor's device;
        non-tensor input is converted to a float32 CPU tensor.
    """
    # Convert before touching ``.device``: array input has no torch device, so the
    # previous order broke the numpy path this function claims to support.
    if not isinstance(image, torch.Tensor):
        image = torch.from_numpy(np.asarray(image)).float()
    device = image.device

    # Draw order (contrast, then brightness) is unchanged so seeded runs reproduce.
    cont_factor = (torch.rand(1, device=device) * 2 - 1) * 0.2 + 1
    bright_factor = (torch.rand(1, device=device) * 2 - 1) * 0.2 + 1

    image = image * bright_factor

    # Mean over every spatial axis (dims 2..N-1) for any input rank.
    spatial_dims = tuple(range(2, image.dim()))
    mean = torch.mean(image, dim=spatial_dims, keepdim=True)

    image = (image - mean) * cont_factor + mean
    return torch.clamp(image, -1, 1)

def calculate_psnr(img1, img2, max_val=2.0):
    """Peak signal-to-noise ratio in dB between two tensors.

    ``max_val`` is the dynamic range of the data. The default 2.0 corresponds to
    images in [-1, 1]. ``img2`` is moved to ``img1``'s device if needed. Identical
    inputs (MSE of exactly 0) return 100.0 instead of infinity.

    Returns:
        0-dim tensor on ``img1``'s device.
    """
    target_device = img1.device
    reference = img2 if img2.device == target_device else img2.to(target_device)

    mse = ((img1 - reference) ** 2).mean()
    if mse == 0:
        return torch.tensor(100.0, device=target_device)

    peak = torch.tensor(max_val, device=target_device)
    return 20 * torch.log10(peak / torch.sqrt(mse))

def move_to_device(tensor, device):
    """Return ``tensor`` on ``device``.

    The same object is returned, with no copy, when its device already compares
    equal to ``device``.
    """
    needs_transfer = tensor.device != device
    return tensor.to(device) if needs_transfer else tensor

def _chunk_window_starts(length, chunk_size, step):
    """Start offsets of ``chunk_size`` windows covering [0, length); the last is flush with the end."""
    if length <= chunk_size:
        return [0]
    starts = list(range(0, length - chunk_size, step))
    # The final window ends exactly at `length`. Each offset appears once, so no
    # window is processed twice.
    starts.append(length - chunk_size)
    return starts


def _chunk_edge_taper(length, ramp, taper_start, taper_end):
    """1-D blend weights: ones, with raised-cosine ramps rising from 0 at each requested edge."""
    weights = torch.ones(length, dtype=torch.float32)
    ramp = min(ramp, length)
    if ramp > 0:
        rise = torch.from_numpy(0.5 * (1 - np.cos(np.pi * np.arange(ramp) / ramp))).float()
        if taper_start:
            weights[:ramp] *= rise
        if taper_end:
            # Mirror the ramp so the outermost sample gets weight 0 on this side too.
            weights[length - ramp:] *= rise.flip(0)
    return weights


def process_image_in_chunks(model, image, args, device, dim='xy'):
    """
    Run ``model`` over a large image in overlapping chunks to bound memory use.

    Chunks are ``args.chunk_size`` LR pixels on a side and overlap by ``args.overlap``
    pixels (more for the last chunk, which is pushed back to end flush with the image).
    Chunk outputs are accumulated on the CPU with raised-cosine weights that fall to 0
    at every interior chunk edge, then normalised by the summed weights. This hides
    seams between chunks.

    Args:
        model: Network mapping a chunk [B, C, h, w] to its upscaled version.
        image: Input tensor [B, C, H, W].
        args: Provides ``chunk_size``, ``overlap`` and ``scale``.
        device: Device each chunk is moved to before calling ``model``.
        dim: 'xy' -> output is [B, C, H*scale, W*scale]. Any other value (used as
            'yz') -> only the first spatial axis is upscaled: [B, C, H*scale, W].

    Returns:
        If both H and W are <= ``chunk_size``, ``model(image)`` computed in one call
        (left on ``device``). Otherwise a float32 CPU tensor of the blended result.

    Raises:
        ValueError: If chunking is needed and ``args.overlap`` is not in
            [0, args.chunk_size).
    """
    batch, channel, height, width = image.shape
    chunk_size = args.chunk_size
    overlap = args.overlap

    with torch.no_grad():
        # Small images go through the model in a single call.
        if height <= chunk_size and width <= chunk_size:
            return model(image.to(device))

        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"args.overlap ({overlap}) must be >= 0 and < args.chunk_size ({chunk_size})"
            )

        # Output scale per axis: 'yz' leaves the horizontal axis at native resolution.
        scale_y = args.scale
        scale_x = args.scale if dim == 'xy' else 1

        output = torch.zeros(batch, channel, height * scale_y, width * scale_x,
                             device='cpu', dtype=torch.float32)
        weight_map = torch.zeros_like(output)
        step = chunk_size - overlap

        for start_i in _chunk_window_starts(width, chunk_size, step):
            end_i = min(start_i + chunk_size, width)
            out_i0, out_i1 = start_i * scale_x, end_i * scale_x
            # Taper only edges that border another chunk; the ramp spans the overlap
            # measured in output pixels along this axis.
            taper_x = _chunk_edge_taper(out_i1 - out_i0, overlap * scale_x,
                                        start_i > 0, end_i < width)

            for start_j in _chunk_window_starts(height, chunk_size, step):
                end_j = min(start_j + chunk_size, height)
                out_j0, out_j1 = start_j * scale_y, end_j * scale_y
                taper_y = _chunk_edge_taper(out_j1 - out_j0, overlap * scale_y,
                                            start_j > 0, end_j < height)

                chunk = image[:, :, start_j:end_j, start_i:end_i].to(device)
                processed_chunk = model(chunk).to('cpu')

                blend = taper_y[:, None] * taper_x[None, :]
                output[:, :, out_j0:out_j1, out_i0:out_i1] += processed_chunk * blend
                weight_map[:, :, out_j0:out_j1, out_i0:out_i1] += blend

        # Normalise by accumulated weights (clamped to avoid division by zero).
        return output / weight_map.clamp(min=1e-8)
    
def authenticate_wandb(api_key=None, entity=None):
    """
    Log in to Weights & Biases, using an explicit API key when one is given.

    Without a key, ``wandb.login()`` is called with no arguments. Success is judged by
    whether ``wandb.api.api_key`` is truthy afterwards. Any exception is reported and
    converted to a ``False`` return instead of being raised.

    Args:
        api_key: wandb API key, or None/empty to use the default login flow.
        entity: Optional username or team name; only used in the success message.

    Returns:
        bool: True if authentication succeeded, False otherwise.
    """
    try:
        if not api_key:
            print("No API key provided. Using default wandb login...")
            wandb.login()
        else:
            print("Authenticating with wandb using API key...")
            wandb.login(key=api_key)

        if not wandb.api.api_key:
            print("Failed to authenticate with wandb.")
            return False

        entity_suffix = ' as ' + entity if entity else ''
        print(f"Successfully authenticated with wandb{entity_suffix}!")
        return True
    except Exception as e:
        print(f"Error authenticating with wandb: {str(e)}")
        return False

def createTrainingCubes2(args, HR, LR, batchsize, cropsize, scale):
    """
    Sample random training blocks from paired HR/LR volumes in XY, YZ and XZ orientation.

    Iteration ``i`` stacks its slices along LR axis 0 (XY), 2 (YZ) or 1 (XZ), cycling
    with ``i % 3``. The stacking axis is moved to the front, so each iteration yields
    ``batchsize`` LR slices of ``cropsize x cropsize`` and the matching
    ``batchsize * scale`` HR slices of ``(cropsize*scale) x (cropsize*scale)``.
    HR is indexed at the LR offsets times ``scale``, so it is assumed to be aligned
    with LR and ``scale`` times larger along every axis. Values are mapped with
    ``v / 127.5 - 1`` ([0, 255] -> [-1, 1]; presumably 8-bit input).

    Offsets are drawn uniformly from every valid position, including the last one.
    The previous ``floor(rand * (n - k))`` could never reach offset ``n - k``.

    Args:
        args: Provides ``itersPerEpoch`` (number of blocks to sample).
        HR: High-resolution volume (3-D array).
        LR: Low-resolution volume (3-D array).
        batchsize: LR slices per block along the stacking axis.
        cropsize: LR in-plane crop size.
        scale: HR/LR size ratio along every axis.

    Returns:
        (batchHR, batchLR): float32 arrays of shape
        [batchsize*itersPerEpoch*scale, cropsize*scale, cropsize*scale, 1] and
        [batchsize*itersPerEpoch, cropsize, cropsize, 1].

    Raises:
        ValueError: If an LR axis is smaller than the extent to crop along it.
    """
    n_iters = args.itersPerEpoch
    batchLR = np.zeros([batchsize * n_iters, cropsize, cropsize, 1], dtype=np.float32)
    batchHR = np.zeros([batchsize * n_iters * scale,
                        cropsize * scale,
                        cropsize * scale,
                        1], dtype=np.float32)

    for i in tqdm(range(n_iters), desc="Creating Training Cubes"):
        # Stacking axis cycles XY (axis 0) -> YZ (axis 2) -> XZ (axis 1).
        batch_axis = (0, 2, 1)[i % 3]
        extents = [cropsize, cropsize, cropsize]
        extents[batch_axis] = batchsize

        offsets = []
        for axis, extent in enumerate(extents):
            max_offset = LR.shape[axis] - extent
            if max_offset < 0:
                raise ValueError(
                    f"LR axis {axis} has size {LR.shape[axis]}, smaller than the required extent {extent}"
                )
            offsets.append(np.random.randint(0, max_offset + 1))  # inclusive of max_offset

        lr_index = tuple(slice(o, o + e) for o, e in zip(offsets, extents))
        hr_index = tuple(slice(o * scale, (o + e) * scale) for o, e in zip(offsets, extents))

        # Put the stacking axis first; the other two keep their relative order
        # (same as the former transposes [2,0,1] for YZ and [1,0,2] for XZ).
        block = np.moveaxis(LR[lr_index], batch_axis, 0)[..., np.newaxis]
        blockHR = np.moveaxis(HR[hr_index], batch_axis, 0)[..., np.newaxis]

        lr_start = i * batchsize
        hr_start = lr_start * scale
        batchLR[lr_start:lr_start + batchsize] = block / 127.5 - 1.0
        batchHR[hr_start:hr_start + batchsize * scale] = blockHR / 127.5 - 1.0

    return batchHR, batchLR

def _mem3d_xy_pass(test_files: List[str], generator_sr: nn.Module, args, device) -> np.ndarray:
    """Upscale every slice in-plane and return the 8-bit results stacked as [Z, H', W']."""
    xy_volume = None
    for z, file_path in tqdm(enumerate(test_files), total=len(test_files), desc="XY Pass"):
        # Load, normalise [0, 255] -> [-1, 1], add batch and channel dimensions.
        slice_z = imageio.imread(file_path)
        slice_z = torch.tensor((slice_z / 127.5) - 1, dtype=torch.float32)
        slice_z = slice_z.unsqueeze(0).unsqueeze(0).to(device)

        with torch.no_grad():
            xy_result = process_image_in_chunks(generator_sr, slice_z, args, device, 'xy')
            # Quantise to 8 bits. Keeping uint8 (not float32) cuts the stored volume 4x.
            xy_result = torch.round((xy_result + 1) * 127.5).clamp(0, 255).to(torch.uint8)

        xy_slice = xy_result.squeeze(0).squeeze(0).cpu().numpy()
        if xy_volume is None:
            # Allocate once the upscaled slice size is known.
            xy_volume = np.empty((len(test_files),) + xy_slice.shape, dtype=np.uint8)
        xy_volume[z] = xy_slice
    return xy_volume


def _mem3d_yz_pass(xy_volume: np.ndarray, generator_src: nn.Module, args, device) -> np.ndarray:
    """Upscale a uint8 [Z, H', W'] volume along Z, in batches of rows, into [Z*scale, H', W']."""
    num_slices, out_height, out_width = xy_volume.shape
    output_volume = np.zeros((num_slices * args.scale, out_height, out_width), dtype=np.uint8)
    rows_per_batch = min(50, out_height)  # Y rows (Z-X cross-sections) per iteration

    for y_start in tqdm(range(0, out_height, rows_per_batch), desc="YZ Pass"):
        y_end = min(y_start + rows_per_batch, out_height)

        # (Z, rows, X) -> (rows, 1, Z, X), then dequantise to [-1, 1] as the XY pass did.
        rows = np.ascontiguousarray(xy_volume[:, y_start:y_end, :].transpose(1, 0, 2))
        rows_tensor = torch.from_numpy(rows).unsqueeze(1).to(device).float() / 127.5 - 1

        with torch.no_grad():
            # One cross-section per call, so each call sees a single [1, 1, Z, X] image.
            processed = torch.cat(
                [process_image_in_chunks(generator_src, rows_tensor[b:b + 1], args, device, 'yz').cpu()
                 for b in range(rows_tensor.size(0))],
                dim=0,
            )

        # Back to [0, 255]. Clip (as in the XY pass) so out-of-range values cannot wrap in uint8.
        processed = np.clip(np.round((processed.numpy() + 1) * 127.5), 0, 255).astype(np.uint8)
        # processed[:, 0] is (rows, Z*scale, X); the output slab is (Z*scale, rows, X).
        output_volume[:, y_start:y_end, :] = processed[:, 0].transpose(1, 0, 2)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_volume


def test_memory_efficient_3d(test_files: List[str], 
                          generator_sr: nn.Module, 
                          generator_src: nn.Module,
                          args, device):
    """
    Two-pass 3-D super-resolution of a stack of 2-D slices, saved as one TIFF volume.

    Pass 1 (XY): each file is read, normalised to [-1, 1], upscaled in-plane by
    ``generator_sr`` (``process_image_in_chunks(..., 'xy')``) and quantised to 8 bits.
    Pass 2 (YZ): for each row y, the (Z, X) cross-section of the pass-1 result is
    upscaled along Z by ``generator_src`` (``process_image_in_chunks(..., 'yz')``).

    Args:
        test_files: Paths of the input slices, in Z order.
        generator_sr: In-plane (XY) super-resolution model.
        generator_src: Single-axis (Z) super-resolution model.
        args: Provides ``scale``, ``test_save_dir``, ``modelName`` and the chunking
            settings used by ``process_image_in_chunks``.
        device: Device the models run on.

    Returns:
        np.ndarray: uint8 volume of shape [len(test_files) * args.scale, H', W'], also
        written to ``{args.test_save_dir}/{args.modelName}/full_volume.tif``.

    Raises:
        ValueError: If ``test_files`` is empty.
    """
    if not test_files:
        raise ValueError("test_files is empty; nothing to process")

    # Ensure output directory exists
    os.makedirs(os.path.join(args.test_save_dir, args.modelName), exist_ok=True)

    print("Starting XY super-resolution pass...")
    xy_volume = _mem3d_xy_pass(test_files, generator_sr, args, device)

    print("Starting YZ super-resolution pass...")
    output_volume = _mem3d_yz_pass(xy_volume, generator_src, args, device)

    # Save full 3D volume
    output_path = f'{args.test_save_dir}/{args.modelName}/full_volume.tif'
    print(f"Saving complete 3D volume with shape [D={output_volume.shape[0]}, H={output_volume.shape[1]}, W={output_volume.shape[2]}] to {output_path}")
    tifffile.imwrite(output_path, output_volume)

    return output_volume

def process_entire_volume(test_files: List[str], 
                          generator_sr: nn.Module, 
                          generator_src: nn.Module,
                          args, device):
    """
    Super-resolve a full 3D volume in two passes, without spatial chunking.

    Intended for when the whole volume fits in memory.

    1. XY pass: every input slice is run through ``generator_sr`` (in batches
       of 32 on CUDA, all at once otherwise). The result is re-quantized to the
       256 8-bit levels and clamped to the valid range.
    2. YZ pass: for each Y row of the XY-upscaled volume, the [Z, W] vertical
       slice is run through ``generator_src``. The result is clamped to [0, 255]
       and written into a uint8 output volume of shape
       [Z * args.scale, H, W].

    The output volume is saved to
    ``{args.test_save_dir}/{args.modelName}/full_volume_direct.tif``.

    Args:
        test_files: Paths of the 2D slice images, in Z order
        generator_sr: XY plane super-resolution model
        generator_src: YZ plane super-resolution model
        args: Configuration; reads ``test_save_dir``, ``modelName`` and ``scale``
        device: Computation device (CPU/GPU)

    Returns:
        The uint8 output volume (also written to disk).

    Raises:
        ValueError: If ``test_files`` is empty.
    """
    if len(test_files) == 0:
        raise ValueError("process_entire_volume: test_files is empty")

    # Ensure output directory exists (makedirs with exist_ok is already idempotent)
    os.makedirs(os.path.join(args.test_save_dir, args.modelName), exist_ok=True)
    
    # Step 1: XY super-resolution for all slices
    print("Starting XY super-resolution pass...")
    
    # Load all XY slices at once
    slices_z = [
        imageio.imread(file_path)
        for file_path in tqdm(test_files, total=len(test_files), desc="Loading XY slices")
    ]
    
    # Stack all slices into a single volume [Z, H, W]
    # NOTE(review): assumes single-channel 2D images; RGB inputs would add an axis — confirm
    xy_volume = np.stack(slices_z, axis=0)
    print(f"Loaded volume shape: {xy_volume.shape}")
    
    # Normalize to [-1, 1]; assumes 8-bit intensities (0-255) — TODO confirm for other bit depths
    xy_volume_norm = (xy_volume / 127.5) - 1
    
    # Batch on GPU to bound memory; on CPU process the whole stack in one batch.
    # len(xy_volume_norm) > 0 is guaranteed above, so the range step is never 0.
    batch_size = 32 if torch.cuda.is_available() else len(xy_volume_norm)
    xy_processed_slices = []
    
    print("Processing XY slices...")
    with torch.no_grad():
        for i in tqdm(range(0, len(xy_volume_norm), batch_size), desc="XY Pass"):
            # Batch of slices as tensor [B, 1, H, W]
            batch_tensor = torch.tensor(xy_volume_norm[i:i + batch_size], dtype=torch.float32).unsqueeze(1)
            batch_tensor = batch_tensor.to(device)
            
            # Process through SR network, then move result to CPU before quantization
            xy_result = generator_sr(batch_tensor).cpu()
            
            # Store one [H, W] array per slice (drop the channel dimension)
            xy_processed_slices.extend(xy_result[:, 0].numpy())
    
    # Stack into a volume [Z, H, W] and read the output dimensions from it
    xy_volume_sr = np.stack(xy_processed_slices, axis=0)
    num_slices, out_height, out_width = xy_volume_sr.shape
    
    # Quantize to 8-bit levels (represented in [-1, 1]). Clamp like `quantize`
    # does, so out-of-range network outputs cannot leave the valid range.
    xy_volume_sr = np.clip(np.round((xy_volume_sr + 1) * 127.5), 0, 255) / 127.5 - 1
    
    # Create output volume for final results.
    # Assumes generator_src upsamples the Z axis by exactly args.scale and keeps W unchanged.
    output_volume = np.zeros((num_slices * args.scale, out_height, out_width), dtype=np.uint8)
    
    # Step 2: YZ super-resolution - process entire volume at once
    print("Starting YZ super-resolution pass...")
    print(f"XY super-resolved volume shape: {xy_volume_sr.shape}")
    
    # Process entire volume row by row along Y dimension
    with torch.no_grad():
        for y in tqdm(range(out_height), desc="YZ Pass"):
            # Vertical slice for this row [Z, W] -> add batch and channel dimensions [1, 1, Z, W]
            vertical_tensor = torch.tensor(xy_volume_sr[:, y, :], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            vertical_tensor = vertical_tensor.to(device)
            
            # Process through YZ SR network
            vertical_sr = generator_src(vertical_tensor)
            
            # Map back to [0, 255]. Clip before the uint8 cast: casting out-of-range
            # floats to uint8 wraps around (e.g. 256 -> 0, -1 -> 255) and creates speckle.
            vertical_sr_np = (vertical_sr.cpu().numpy() + 1) * 127.5
            output_volume[:, y, :] = np.clip(np.round(vertical_sr_np[0, 0]), 0, 255).astype(np.uint8)
    
    # Release cached GPU memory once the pass is done; doing it every row forces
    # the caching allocator to re-allocate on each iteration.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Save full 3D volume
    output_path = f'{args.test_save_dir}/{args.modelName}/full_volume_direct.tif'
    print(f"Saving complete 3D volume with shape [D={output_volume.shape[0]}, H={output_volume.shape[1]}, W={output_volume.shape[2]}] to {output_path}")
    tifffile.imwrite(output_path, output_volume)
    
    return output_volume

def quantize(x):
    """Snap a tensor in [-1, 1] onto the 256 8-bit intensity levels.

    The input is mapped to the [0, 255] pixel scale. It is then rounded with
    ``StraightThroughRound`` (so rounding stays usable inside a training graph),
    clamped to the valid 8-bit range, and mapped back to the [-1, 1] scale.
    """
    pixel_scale = StraightThroughRound.apply((x + 1) * 127.5)
    return pixel_scale.clamp(0, 255) / 127.5 - 1

In [6]:
def main():
    # Define arguments
    args = Args()

    # Create checkpoint directory
    training_dir = f"{args.checkpoint_dir}/{args.modelName}"
    if not os.path.exists(training_dir):
        os.makedirs(training_dir, exist_ok=True)
    
    # Set device
    device = torch.device(f'cuda:{args.gpuIDs}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize models
    generator_sr = EDSR(scale=args.scale, num_filters=args.ngsrf, num_res_blocks=args.numResBlocks, ndims=2)
    generator_src = EDSR1D(scale=args.scale, num_filters=args.ngsrf//2, num_res_blocks=args.numResBlocks//2, ndims=2)
    
    # Move models to device
    generator_sr = generator_sr.to(device)
    generator_src = generator_src.to(device)
    
    # Initialize optimizers
    optimizer_gen_sr = optim.Adam(generator_sr.parameters(), lr=args.lr)
    optimizer_gen_src = optim.Adam(generator_src.parameters(), lr=args.lr)
    
    # Initialize CosineAnnealingWarmRestarts scheduler
    scheduler_gen_sr = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer_gen_sr,
        T_0=args.epoch_step,  # Initial restart interval
        T_mult=2,  # Increase T_0 by this factor after each restart
        eta_min=args.eta_min  # Minimum learning rate
    )

    scheduler_gen_src = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer_gen_src,
        T_0=args.epoch_step,  # Initial restart interval
        T_mult=2,  # Increase T_0 by this factor after each restart
        eta_min=args.eta_min  # Minimum learning rate
    )
    
    # Load checkpoints for testing or continuing training
    if args.phase == 'test':
        print(f'Loading checkpoints from {training_dir}')
        try:
            # Try to load the best model first if specified
            if args.use_best_model:
                best_model_path = f'{training_dir}/GSR-best.pth'
                if os.path.exists(best_model_path):
                    print(f"Loading best model from {best_model_path}")
                    generator_sr.load_state_dict(torch.load(best_model_path, map_location=device))
                    generator_src.load_state_dict(torch.load(f'{training_dir}/GSRC-best.pth', map_location=device))
                    print("Successfully loaded best models")
                else:
                    # Fall back to specified epoch
                    print(f"Best model not found, loading epoch {args.continueEpoch}")
                    generator_sr.load_state_dict(torch.load(f'{training_dir}/GSR-{args.continueEpoch}.pth', map_location=device))
                    generator_src.load_state_dict(torch.load(f'{training_dir}/GSRC-{args.continueEpoch}.pth', map_location=device))
                    print(f"Successfully loaded models from epoch {args.continueEpoch}")
            else:
                # Load specified epoch directly
                generator_sr.load_state_dict(torch.load(f'{training_dir}/GSR-{args.continueEpoch}.pth', map_location=device))
                generator_src.load_state_dict(torch.load(f'{training_dir}/GSRC-{args.continueEpoch}.pth', map_location=device))
                print(f"Successfully loaded models from epoch {args.continueEpoch}")
        except Exception as e:
            print(f'Could not load SR related weights: {str(e)}')
            print('Will start with fresh weights')
    elif args.continue_train:
        print(f'Loading checkpoints from {training_dir} for epoch {args.continueEpoch}')
        try:
            # Load model weights
            generator_sr.load_state_dict(torch.load(f'{training_dir}/GSR-{args.continueEpoch}.pth', map_location=device))
            generator_src.load_state_dict(torch.load(f'{training_dir}/GSRC-{args.continueEpoch}.pth', map_location=device))
            
            # Try to load optimizer state if available
            try:
                gen_sr_checkpoint = torch.load(f'{training_dir}/GSR-{args.continueEpoch}_full.pth', map_location=device)
                gen_src_checkpoint = torch.load(f'{training_dir}/GSRC-{args.continueEpoch}_full.pth', map_location=device)
                
                optimizer_gen_sr.load_state_dict(gen_sr_checkpoint['optimizer_state_dict'])
                optimizer_gen_src.load_state_dict(gen_src_checkpoint['optimizer_state_dict'])
                print("Successfully loaded optimizers")
                
                # Fast-forward scheduler to current epoch
                for _ in range(args.continueEpoch):
                    scheduler_gen_sr.step()
                    scheduler_gen_src.step()
                    
            except Exception as e:
                print(f"Could not load optimizer states: {str(e)}")
                print("Will use fresh optimizers")
                
            print("Successfully loaded model weights")
        except Exception as e:
            print(f'Could not load SR related weights: {str(e)}')
            print('Will start with fresh weights')
    
    # Training phase
    if args.phase == 'train':
        # Create output directories
        right_now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        val_out_dir = args.dataset_dir.split('/')[-1] if '/' in args.dataset_dir else args.dataset_dir.split('\\\\')[-1]
        train_output_dir = f'./training_outputs/{right_now}-pytorch-{val_out_dir}-{args.modelName}/'
        os.makedirs(train_output_dir, exist_ok=True)
        
        # Authenticate with wandb before initializing
        if not authenticate_wandb(args.wandb_api_key, args.wandb_entity):
            print("Warning: wandb authentication failed. Continuing without experiment tracking.")
        
        # Initialize wandb for tracking
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=f"{args.modelName}-{right_now}",
            config={
                "model": args.modelName,
                "scale": args.scale,
                "learning_rate": args.lr,
                "epochs": args.epoch,
                "batch_size": args.batch_size,
                "patch_size": args.patch_size,
                "num_res_blocks": args.numResBlocks,
                "filters_sr": args.ngsrf,
                "filters_src": args.ngsrf//2,
                "dataset": val_out_dir,
                "scheduler": "CosineAnnealingWarmRestarts",
                "scheduler_T0": args.epoch_step,
                "scheduler_T_mult": 2,
                "scheduler_eta_min": args.eta_min
            }
        )
        
        # Log model architecture as a string
        wandb.run.summary["model_sr"] = str(generator_sr)
        wandb.run.summary["model_src"] = str(generator_src)

        # Initialize metric tracking
        train_losses = {
            'epoch': [],
            'sr_xy_loss': [],
            'sr_yz_loss': [],
            'total_loss': [],
            'global_step': []
        }

        val_metrics = {
            'epoch': [],
            'psnr_xy': [],
            'psnr_xyz': [],
            'l1_loss': [],
            'global_step': []
        }
        
        # Initialize batch-level losses for more detailed tracking
        batch_losses = []
        
        # Initialize best model tracking
        best_model = {
            'epoch': 0,
            'val_psnr_xyz': 0.0,
            'val_psnr_xy': 0.0,
            'val_l1_loss': float('inf')
        }

        print('2D/3D training specified, datasets will be randomly mini-batched per epoch')
        print('2D/3D dataset and training -> data will be fully preloaded into RAM')
        
        # Load training data
        BC_loc = glob(os.path.join(args.dataset_dir, 'LR.npy'))
        if not BC_loc:
            raise FileNotFoundError(f"LR.npy not found in {args.dataset_dir}")
        print(f"Loading LR data from: {BC_loc[0]}")
        LR = np.load(BC_loc[0])
        
        HR_loc = glob(os.path.join(args.dataset_dir, 'HR.npy'))
        if not HR_loc:
            raise FileNotFoundError(f"HR.npy not found in {args.dataset_dir}")
        print(f"Loading HR data from: {HR_loc[0]}")
        HR = np.load(HR_loc[0])
        
        print(f"LR shape: {LR.shape}, HR shape: {HR.shape}")
        wandb.run.summary["dataset_lr_shape"] = str(LR.shape)
        wandb.run.summary["dataset_hr_shape"] = str(HR.shape)
        
        # Load test data if needed
        if args.valTest:
            LR_test_loc = glob(os.path.join(args.dataset_dir, 'test', '*'))
            if not LR_test_loc:
                print(f"Warning: No test data found in {os.path.join(args.dataset_dir, 'test')}")
            else:
                LR_test = np.load(LR_test_loc[0])
                LR_test = torch.tensor(LR_test, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
                LR_test = LR_test.to(device)
        
        start_time = time.time()
        global_step = 0  # For wandb step tracking
        
        for epoch in range(args.epoch):
            # Adjust batch and crop size
            total_per_batch_voxels = args.patch_size * args.patch_size * args.batch_size
            min_per_dim_size = args.batch_size
            max_per_dim_size = args.patch_size
            batch_size_this_epoch = int(np.floor(np.random.rand()*(max_per_dim_size-min_per_dim_size))+min_per_dim_size)
            patch_size_this_epoch = int(np.floor(np.sqrt(total_per_batch_voxels/batch_size_this_epoch)))
            
            print(f'Reading dataset, block size this epoch: {batch_size_this_epoch} x {patch_size_this_epoch} x {patch_size_this_epoch} -> {args.scale}x')
            
            # Create training data batches
            real_HR_batches, real_BC_batches = createTrainingCubes2(args, HR, LR, batch_size_this_epoch, patch_size_this_epoch, args.scale)
            
            # Convert to PyTorch tensors
            HR_dataset = torch.from_numpy(real_HR_batches).permute(0, 3, 1, 2).float()  # [B, C, H, W]
            LR_dataset = torch.from_numpy(real_BC_batches).permute(0, 3, 1, 2).float()  # [B, C, H, W]
            
            # Create validation subset
            HR_val_dataset = HR_dataset[:args.valNum*batch_size_this_epoch*args.scale]
            LR_val_dataset = LR_dataset[:args.valNum*batch_size_this_epoch]

            # Prepare data loaders
            train_dataset = DualSRDataset(HR_dataset, LR_dataset, args.scale)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size_this_epoch,
                shuffle=False
            )
            
            # Get learning rate
            current_lr = optimizer_gen_sr.param_groups[0]['lr']
            print(f'Learning Rate: {current_lr:.6e}')

            # Log learning rate to wandb
            wandb.log({
                "learning_rate": current_lr,
                "epoch": epoch+1
            }, step=global_step)
            
            # Training loop variables
            tot_g_sr_xy_loss = 0
            tot_g_sr_yz_loss = 0
            num_batches = 0
            last_time = time.time()
            
            # Set models to training mode
            generator_sr.train()
            generator_src.train()
            
            # Iterator for cycling through data multiple times if needed
            train_iter = iter(train_loader)
            
            # Main training loop
            pbar = tqdm(total=args.itersPerEpoch * args.iterCyclesPerEpoch, desc=f"Epoch {epoch+1}/{args.epoch}")
            
            while num_batches < args.itersPerEpoch * args.iterCyclesPerEpoch:
                # Get next batch or reset iterator
                try:
                    B_xy, C_xyz_group = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    B_xy, C_xyz_group = next(train_iter)

                # Move data to device
                B_xy = move_to_device(B_xy, device)
                C_xyz_group = move_to_device(C_xyz_group, device)
                
                # Data augmentation if enabled
                if args.augFlag:
                    B_xy = augment_data(B_xy)
                
                # Zero gradients
                optimizer_gen_sr.zero_grad()
                optimizer_gen_src.zero_grad()
                
                # Take the first HR slice (from the group) for XY model
                # We can reshape the batch dimension to match what C_xyz_downsampled expects
                C_xyz = C_xyz_group.view(-1, C_xyz_group.size(2), C_xyz_group.size(3), C_xyz_group.size(4))
                C_xy_downsampled = F.interpolate(C_xyz.permute(1, 2, 0, 3), size=(C_xyz.size(0)//args.scale, C_xyz.size(2)), mode='bicubic', align_corners=False)
                C_xy_downsampled = C_xy_downsampled.permute(2, 0, 1, 3)
                
                # Forward pass XY SR
                SR_xy = generator_sr(B_xy)
                loss_sr_xy = F.l1_loss(SR_xy, C_xy_downsampled)

                # Quantize to 8-bit
                SR_xy_quantized = quantize(SR_xy)
                
                # SR_xy has shape [B, C, H, W]
                SR_xy_t = SR_xy_quantized.permute(2, 1, 0, 3).to(device)  # [H, C, B, W]
                
                # Forward pass YZ SR
                SR_xyz = generator_src(SR_xy_t)
                C_xyz_t = C_xyz.permute(2, 1, 0, 3)  # [H, C, B, W]
                loss_sr_yz = F.l1_loss(SR_xyz, C_xyz_t)
                
                # Combined loss
                total_loss = loss_sr_xy + loss_sr_yz
                total_loss.backward()
                
                # Update parameters
                optimizer_gen_sr.step()
                optimizer_gen_src.step()
                
                # Get loss values
                loss_sr_xy_val = loss_sr_xy.item()
                loss_sr_yz_val = loss_sr_yz.item()
                
                # Update counters and stats
                tot_g_sr_xy_loss += loss_sr_xy_val
                tot_g_sr_yz_loss += loss_sr_yz_val
                num_batches += 1
                global_step += 1
                
                # Store batch-level losses
                batch_losses.append({
                    'epoch': epoch+1,
                    'batch': num_batches,
                    'sr_xy_loss': loss_sr_xy_val,
                    'sr_yz_loss': loss_sr_yz_val,
                    'total_loss': loss_sr_xy_val + loss_sr_yz_val,
                    'global_step': global_step,
                    'lr': current_lr
                })
                
                # Log metrics to wandb every 10 batches
                if num_batches % 10 == 0:
                    wandb.log({
                        "train/sr_xy_loss": loss_sr_xy_val,
                        "train/sr_yz_loss": loss_sr_yz_val,
                        "train/total_loss": loss_sr_xy_val + loss_sr_yz_val,
                        "train/iteration_time": time.time() - last_time,
                        "train/iterations_per_second": 1.0 / (time.time() - last_time if time.time() > last_time else 1e-5),
                    }, step=global_step)
                
                current_time = time.time()
                
                pbar.set_postfix({
                    'Time': f"{current_time-start_time:.2f}s", 
                    'Speed': f"{1/(current_time-last_time if current_time > last_time else 1e-5):.2f} it/s", 
                    'GSRxyL': f"{loss_sr_xy_val:.4f}", 
                    'GSRyzL': f"{loss_sr_yz_val:.4f}"
                })
                pbar.update(1)
                
                last_time = current_time
                
                # Break if reached target iterations
                if num_batches >= args.itersPerEpoch * args.iterCyclesPerEpoch:
                    break

            pbar.close()
            
            # Calculate epoch statistics
            tot_g_sr_xy_loss /= num_batches
            tot_g_sr_yz_loss /= num_batches
            print(f'Mean Epoch Performance: GSRxyL: {tot_g_sr_xy_loss:.4f}, GSRyzL: {tot_g_sr_yz_loss:.4f}')
            
            # Update epoch-level metrics
            train_losses['epoch'].append(epoch+1)
            train_losses['sr_xy_loss'].append(tot_g_sr_xy_loss)
            train_losses['sr_yz_loss'].append(tot_g_sr_yz_loss)
            train_losses['total_loss'].append(tot_g_sr_xy_loss + tot_g_sr_yz_loss)
            train_losses['global_step'].append(global_step)
            
            # Log epoch statistics to wandb
            wandb.log({
                "epoch/sr_xy_loss": tot_g_sr_xy_loss,
                "epoch/sr_yz_loss": tot_g_sr_yz_loss,
                "epoch/total_loss": tot_g_sr_xy_loss + tot_g_sr_yz_loss,
                "epoch": epoch+1
            }, step=global_step)
            
            # Validation and visualization
            if np.mod(epoch+1, args.print_freq) == 0 or epoch == 0:
                # Create epoch output directory
                
                
                # Set models to evaluation mode
                generator_sr.eval()
                generator_src.eval()
                
                val_psnr_xy = 0.0
                val_psnr_xyz = 0.0
                val_l1_loss_total = 0.0
                num_test_batches = 0
                

                # Validation loop
                val_pbar = tqdm(total=args.valNum, desc="Validation")
                with torch.no_grad():
                    for i in range(0, len(HR_val_dataset), batch_size_this_epoch):
                        if i + batch_size_this_epoch > len(HR_val_dataset):
                            break
                            
                        # Create a validation subset with proper structure
                        val_lr = LR_val_dataset[i:i+batch_size_this_epoch].to(device)
                        val_hr_start = i * args.scale
                        val_hr_end = val_hr_start + batch_size_this_epoch * args.scale
                        val_hr = HR_val_dataset[val_hr_start:val_hr_end].to(device)

                        # Reshape val_hr to match our dataset format [B, C, H, W]
                        val_hr_downsampled = F.interpolate(val_hr.permute(1, 2, 0, 3), size=(val_hr.size(0)//args.scale, val_hr.size(2)), mode='bicubic', align_corners=False)
                        val_hr_downsampled = val_hr_downsampled.permute(2, 0, 1, 3)
                        
                        # Forward pass XY
                        generated_xy = generator_sr(val_lr)
                        
                        # Calculate PSNR
                        psnr_xy = calculate_psnr(generated_xy, val_hr_downsampled)
                        
                        # Store results on CPU after calculation
                        generated_xy_output = generated_xy.detach().cpu().numpy()

                        # Quantize to 8-bit
                        generated_xy = (generated_xy + 1) * 127.5
                        generated_xy = torch.round(generated_xy).clamp(0, 255).to(torch.uint8)
                        generated_xy = generated_xy / 127.5 - 1

                        # Transpose for YZ dimension (ensure on device)
                        generated_xy_t = generated_xy.permute(2, 1, 0, 3).to(device)
                        

                        # Forward pass YZ
                        generated_xyz = generator_src(generated_xy_t)
                        generated_xyz = generated_xyz.permute(2, 1, 0, 3)
                        psnr_xyz = calculate_psnr(generated_xyz, val_hr)
                        generated_xyz_output = generated_xyz.detach().cpu().numpy()
                        
                        # Calculate L1 loss explicitly
                        batch_l1_loss = F.l1_loss(generated_xyz, val_hr).item()
                        
                        # Update statistics
                        val_psnr_xy += psnr_xy.item()
                        val_psnr_xyz += psnr_xyz.item()
                        val_l1_loss_total += batch_l1_loss
                        num_test_batches += 1
                        
                        val_lr_np = (val_lr.detach().cpu().numpy() + 1) * 127.5
                        val_hr_np = (val_hr.detach().cpu().numpy() + 1) * 127.5
                        val_hr_downsampled_np = (val_hr_downsampled.detach().cpu().numpy() + 1) * 127.5
                        generated_xy_np = (generated_xy_output + 1) * 127.5
                        generated_xyz_np = (generated_xyz_output + 1) * 127.5
                        
                        # Save images periodically
                        if np.mod(epoch+1, args.save_freq) == 0 or epoch == args.epoch - 1:
                            os.makedirs(f'./{train_output_dir}/epoch-{epoch+1}/', exist_ok=True)
                            
                            image_path = f'./{train_output_dir}/epoch-{epoch+1}/{num_test_batches}-LRxyz.tif'
                            tifffile.imwrite(image_path, np.squeeze(val_lr_np.astype('uint8')))

                            image_path = f'./{train_output_dir}/epoch-{epoch+1}/{num_test_batches}-HRxyz.tif'
                            tifffile.imwrite(image_path, np.squeeze(val_hr_np.astype('uint8')))

                            image_path = f'./{train_output_dir}/epoch-{epoch+1}/{num_test_batches}-HRxyz-downsampled.tif'
                            tifffile.imwrite(image_path, np.squeeze(val_hr_downsampled_np.astype('uint8')))

                            image_path = f'./{train_output_dir}/epoch-{epoch+1}/{num_test_batches}-SRxy.tif'
                            tifffile.imwrite(image_path, np.squeeze(generated_xy_np.astype('uint8')))

                            image_path = f'./{train_output_dir}/epoch-{epoch+1}/{num_test_batches}-SRxyz.tif'
                            tifffile.imwrite(image_path, np.squeeze(generated_xyz_np.astype('uint8')))
                        
                        # Log images to wandb (just first batch)
                        if num_test_batches == 1:
                            # Log sample images to wandb
                            wandb.log({
                                "images/lr": wandb.Image(np.squeeze(val_lr_np[0].astype(np.uint8)), 
                                                      caption="Low Resolution"),
                                "images/hr": wandb.Image(np.squeeze(val_hr_np[0].astype(np.uint8)), 
                                                       caption="High Resolution"),
                                "images/hr_downsampled": wandb.Image(np.squeeze(val_hr_downsampled_np[0].astype(np.uint8)), 
                                                                  caption="HR Downsampled"),
                                "images/sr_xy": wandb.Image(np.squeeze(generated_xy_np[0].astype(np.uint8)), 
                                                          caption="SR XY"),
                                "images/sr_xyz": wandb.Image(np.squeeze(generated_xyz_np[0].astype(np.uint8)), 
                                                           caption="SR XYZ"),
                            }, step=global_step)
                        
                        val_pbar.set_postfix({
                            'PSNR-SR': f"{psnr_xy.item():.4f}", 
                            'PSNR-SRC': f"{psnr_xyz.item():.4f}"
                        })
                        val_pbar.update(1)
                        
                        if num_test_batches >= args.valNum:
                            break
                val_pbar.close()
                
                # Calculate average validation metrics
                val_psnr_xy /= num_test_batches
                val_psnr_xyz /= num_test_batches
                val_l1_loss = val_l1_loss_total / num_test_batches
                
                # Update validation metrics
                val_metrics['epoch'].append(epoch+1)
                val_metrics['psnr_xy'].append(val_psnr_xy)
                val_metrics['psnr_xyz'].append(val_psnr_xyz)
                val_metrics['l1_loss'].append(val_l1_loss)
                val_metrics['global_step'].append(global_step)
                
                stdout.write("\n")
                print(f'Mean Validation PSNR-SR: {val_psnr_xy}, PSNR-SRC: {val_psnr_xyz}, L1-Loss: {val_l1_loss}')
                
                # Log validation metrics to wandb
                wandb.log({
                    "val/psnr_xy": val_psnr_xy,
                    "val/psnr_xyz": val_psnr_xyz,
                    "val/l1_loss": val_l1_loss
                }, step=global_step)

                # Use weighted score for model selection (PSNR and L1 loss)
                psnr_weight = 0.7
                l1_weight = 0.3
                current_score = (psnr_weight * val_psnr_xyz) - (l1_weight * val_l1_loss)
                best_score = (psnr_weight * best_model['val_psnr_xyz']) - (l1_weight * best_model['val_l1_loss'])

                if current_score > best_score:
                    best_model['epoch'] = epoch + 1
                    best_model['val_psnr_xyz'] = val_psnr_xyz
                    best_model['val_psnr_xy'] = val_psnr_xy
                    best_model['val_l1_loss'] = val_l1_loss
                    
                    print(f"\nNew best model at epoch {epoch+1}!")
                    print(f"PSNR-XY: {val_psnr_xy:.4f}, PSNR-XYZ: {val_psnr_xyz:.4f}, L1 Loss: {val_l1_loss:.4f}")
                    
                    # Save best models
                    torch.save(generator_sr.state_dict(), f'{training_dir}/GSR-best.pth')
                    torch.save(generator_src.state_dict(), f'{training_dir}/GSRC-best.pth')
                    
                    # Full checkpoint for potential resuming
                    torch.save({
                        'epoch': epoch+1,
                        'model_state_dict': generator_sr.state_dict(),
                        'optimizer_state_dict': optimizer_gen_sr.state_dict(),
                        'best_psnr': val_psnr_xyz,
                        'best_l1': val_l1_loss
                    }, f'{training_dir}/GSR-best_full.pth')
                    
                    torch.save({
                        'epoch': epoch+1,
                        'model_state_dict': generator_src.state_dict(),
                        'optimizer_state_dict': optimizer_gen_src.state_dict(),
                        'best_psnr': val_psnr_xyz,
                        'best_l1': val_l1_loss
                    }, f'{training_dir}/GSRC-best_full.pth')
                    
                    # Log best metrics to wandb
                    wandb.run.summary["best_epoch"] = epoch + 1
                    wandb.run.summary["best_psnr_xyz"] = val_psnr_xyz
                    wandb.run.summary["best_psnr_xy"] = val_psnr_xy
                    wandb.run.summary["best_l1_loss"] = val_l1_loss
                    
                    # Additionally log as metrics
                    wandb.log({
                        "best/epoch": epoch + 1,
                        "best/psnr_xyz": val_psnr_xyz,
                        "best/psnr_xy": val_psnr_xy,
                        "best/l1_loss": val_l1_loss
                    }, step=global_step)
                
                # Test on separate test data if enabled
                if args.valTest and 'LR_test' in locals():
                    print(f'Generating some test cubes')
                    with torch.no_grad():
                        test_sr_xy = generator_sr(LR_test)
                        test_sr_xy_np = test_sr_xy.cpu().numpy()
                        image_path = f'./{train_output_dir}/epoch-{epoch+1}/testSRxy.tif'
                        test_sr_xy_out = (test_sr_xy_np + 1) * 127.5
                        tifffile.imwrite(image_path, np.squeeze(test_sr_xy_out.astype('uint8')))
                        
                        # Log test image to wandb
                        wandb.log({
                            "test/sr_xy": wandb.Image(np.squeeze(test_sr_xy_out[0].astype(np.uint8)), 
                                                    caption="Test SR XY"),
                        }, step=global_step)
            
            # Save models periodically
            if (epoch+1) % args.save_freq == 0 or epoch == args.epoch - 1:
                print('Saving network weights (archive)')
                torch.save(generator_sr.state_dict(), f'{training_dir}/GSR-{epoch+1}.pth')
                torch.save(generator_src.state_dict(), f'{training_dir}/GSRC-{epoch+1}.pth')
                
                print('Saving network weights (rewritable checkpoint)')
                torch.save(generator_sr.state_dict(), f'{training_dir}/GSR.pth')
                torch.save(generator_src.state_dict(), f'{training_dir}/GSRC.pth')
                
                print('Saving model (rewritable checkpoint)')
                torch.save({
                    'epoch': epoch+1,
                    'model_state_dict': generator_sr.state_dict(),
                    'optimizer_state_dict': optimizer_gen_sr.state_dict(),
                }, f'{training_dir}/GSR-{epoch+1}_full.pth')
                
                torch.save({
                    'epoch': epoch+1,
                    'model_state_dict': generator_src.state_dict(),
                    'optimizer_state_dict': optimizer_gen_src.state_dict(),
                }, f'{training_dir}/GSRC-{epoch+1}_full.pth')
                
                # Save model checkpoint to wandb
                wandb.save(f'{training_dir}/GSR-{epoch+1}.pth')
                wandb.save(f'{training_dir}/GSRC-{epoch+1}.pth')

            # Update the learning rate after each epoch
            scheduler_gen_sr.step()
            scheduler_gen_src.step()
                
        # After training is completed
        # Create directory for storing metrics
        os.makedirs(f'{train_output_dir}/metrics', exist_ok=True)

        # Save training losses
        train_df = pd.DataFrame(train_losses)
        train_df.to_csv(f'{train_output_dir}/metrics/training_losses.csv', index=False)
        print(f"Training losses saved to {train_output_dir}/metrics/training_losses.csv")

        # Save validation metrics
        val_df = pd.DataFrame(val_metrics)
        val_df.to_csv(f'{train_output_dir}/metrics/validation_metrics.csv', index=False)
        print(f"Validation metrics saved to {train_output_dir}/metrics/validation_metrics.csv")
        
        # Save batch-level losses for more detailed analysis
        batch_df = pd.DataFrame(batch_losses)
        batch_df.to_csv(f'{train_output_dir}/metrics/batch_losses.csv', index=False)
        print(f"Batch-level losses saved to {train_output_dir}/metrics/batch_losses.csv")

        # Create a basic plot as a quick reference
        try:
            import matplotlib.pyplot as plt
            
            # Plot training loss
            plt.figure(figsize=(20, 10))
            
            plt.subplot(2, 2, 1)
            plt.plot(train_df['epoch'], train_df['sr_xy_loss'], 'b-', label='XY Loss')
            plt.plot(train_df['epoch'], train_df['sr_yz_loss'], 'r-', label='YZ Loss')
            plt.plot(train_df['epoch'], train_df['total_loss'], 'g-', label='Total Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Losses')
            plt.legend()
            plt.grid(True)
            
            # Plot validation metrics
            plt.subplot(2, 2, 2)
            plt.plot(val_df['epoch'], val_df['psnr_xy'], 'b-o', label='PSNR XY')
            plt.plot(val_df['epoch'], val_df['psnr_xyz'], 'r-o', label='PSNR XYZ')
            plt.xlabel('Epoch')
            plt.ylabel('PSNR (dB)')
            plt.title('Validation PSNR')
            plt.legend()
            plt.grid(True)
            
            # Plot learning rate schedule
            plt.subplot(2, 2, 3)
            # Extract learning rates - recreate the schedule curve
            epochs = np.arange(1, args.epoch + 1)
            lrs = []
            
            # Create a temporary optimizer and scheduler to simulate LR curve
            temp_optimizer = optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
            temp_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                temp_optimizer, T_0=args.epoch_step, T_mult=2, eta_min=args.eta_min
            )
            
            # Simulate scheduler steps
            for _ in epochs:
                lrs.append(temp_optimizer.param_groups[0]['lr'])
                temp_scheduler.step()
                
            plt.plot(epochs, lrs, 'g-')
            plt.xlabel('Epoch')
            plt.ylabel('Learning Rate')
            plt.title('Cosine Annealing Warm Restarts Learning Rate Schedule')
            plt.grid(True)
            
            # Plot batch-level losses (smoothed)
            plt.subplot(2, 2, 4)
            window_size = 50  # Moving average window for smoothing
            if len(batch_df) > window_size:
                plt.plot(batch_df['global_step'], batch_df['sr_xy_loss'].rolling(window=window_size).mean(), 'b-', alpha=0.7, label=f'XY Loss (MA{window_size})')
                plt.plot(batch_df['global_step'], batch_df['sr_yz_loss'].rolling(window=window_size).mean(), 'r-', alpha=0.7, label=f'YZ Loss (MA{window_size})')
                plt.plot(batch_df['global_step'], batch_df['total_loss'].rolling(window=window_size).mean(), 'g-', alpha=0.7, label=f'Total Loss (MA{window_size})')
                plt.xlabel('Global Step')
                plt.ylabel('Loss')
                plt.title('Batch-level Losses (Moving Average)')
                plt.legend()
                plt.grid(True)
            else:
                plt.title('Not enough batches for smoothed visualization')
                plt.grid(True)
            
            plt.tight_layout()
            plt.savefig(f'{train_output_dir}/metrics/training_summary.png', dpi=300)
            plt.close()
            print(f"Summary plot saved to {train_output_dir}/metrics/training_summary.png")
        except Exception as e:
            print(f"Could not create summary plot: {str(e)}")
        wandb.finish()

    # Test on larger volumes - save as 3D TIFF
    elif args.phase == 'test':
        # Set models to evaluation mode
        generator_sr.eval()
        generator_src.eval()
        
        # Find test files
        test_files = sorted(glob(args.test_dir + '/*.png'))
        if not test_files:
            print(f"No PNG files found in {args.test_dir}")
            return
            
        print(f"Found {len(test_files)} test files")
        
        # Create output directory
        os.makedirs(os.path.join(args.test_save_dir, args.modelName), exist_ok=True)
        
        # Check available RAM before processing
        mem_info = psutil.virtual_memory()
        available_ram_gb = mem_info.available / 1e9
        print(f"Available RAM: {available_ram_gb:.1f} GB")
        
        # Estimate memory needed based on test file size
        test_sample = imageio.imread(test_files[0])
        bytes_per_voxel = 4  # float32
        vol_shape_xy = (len(test_files), test_sample.shape[0] * args.scale, test_sample.shape[1] * args.scale)
        estimated_peak_gb = (vol_shape_xy[0] * vol_shape_xy[1] * vol_shape_xy[2] * bytes_per_voxel * 2) / 1e9
        print(f"Estimated peak memory usage: {estimated_peak_gb:.1f} GB")
        
        # Choose processing method based on available RAM
        if available_ram_gb > estimated_peak_gb * 1.5:  # Add 50% safety margin
            print(f"Sufficient RAM available. Processing volume directly without chunking.")
            output_volume = process_entire_volume(test_files, generator_sr, generator_src, args, device)
        else:
            print(f"Limited RAM available ({available_ram_gb:.1f} GB). Using chunked processing.")
            output_volume = test_memory_efficient_3d(test_files, generator_sr, generator_src, args, device)
        
        # Test metrics calculation if ground truth is available
        gt_dir = os.path.join(os.path.dirname(args.test_dir), 'HR')
        gt_files = sorted(glob(f"{gt_dir}/*.png"))
        
        if gt_files and len(gt_files) * args.scale == output_volume.shape[0]:
            print("Found ground truth files. Calculating metrics...")
            
            # Load ground truth volume
            gt_volume = []
            for gt_file in tqdm(gt_files, desc="Loading ground truth"):
                gt_slice = imageio.imread(gt_file)
                # Ensure same size as output
                if gt_slice.shape != (output_volume.shape[1], output_volume.shape[2]):
                    print(f"Warning: GT size mismatch. Resizing GT from {gt_slice.shape} to match output {output_volume.shape[1:]}")
                    # You might want to add resizing here if needed
                gt_volume.append(gt_slice)
            
            # Convert stacks to volume
            gt_slices = np.stack(gt_volume, axis=0)
            gt_volume = np.zeros((output_volume.shape[0], output_volume.shape[1], output_volume.shape[2]), dtype=np.uint8)
            
            # Interpolate along Z to match output scale
            for y in range(output_volume.shape[1]):
                for x in range(output_volume.shape[2]):
                    gt_profile = gt_slices[:, y, x]
                    gt_volume[:, y, x] = np.interp(
                        np.linspace(0, len(gt_profile)-1, output_volume.shape[0]),
                        np.arange(len(gt_profile)),
                        gt_profile
                    )
            
            # Calculate metrics
            mse = np.mean((output_volume.astype(np.float32) - gt_volume.astype(np.float32)) ** 2)
            psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else float('inf')
            mae = np.mean(np.abs(output_volume.astype(np.float32) - gt_volume.astype(np.float32)))
            
            print(f"Test metrics - PSNR: {psnr:.4f}, MAE: {mae:.4f}")
            
            # Save metrics to CSV
            metrics_df = pd.DataFrame({
                'metric': ['PSNR', 'MAE', 'MSE'],
                'value': [psnr, mae, mse],
                'model': args.modelName,
                'dataset': os.path.basename(args.test_dir)
            })
            metrics_df.to_csv(f'{args.test_save_dir}/{args.modelName}/test_metrics.csv', index=False)
            
            print(f"Test metrics saved to {args.test_save_dir}/{args.modelName}/test_metrics.csv")
        
        print(f"Testing completed. Results saved to {args.test_save_dir}/{args.modelName}/")

In [None]:
# Entry point cell: build the run configuration and launch main().
# Inside a notebook kernel __name__ is always "__main__", so executing this cell runs main().
if __name__ == "__main__":
    args = Args()
    # wandb credentials come from the environment instead of being hardcoded in the notebook.
    # If WANDB_API_KEY is unset this is None, and main() falls back to the default wandb login.
    args.wandb_api_key = os.environ.get("WANDB_API_KEY")
    # Optional wandb entity (username or team name), read from WANDB_ENTITY rather than a
    # hardcoded placeholder string; None means no explicit entity is requested.
    # NOTE(review): assumes main() treats a None entity as "use wandb's default" — confirm.
    args.wandb_entity = os.environ.get("WANDB_ENTITY")
    main()

Using device: cuda:0


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


No API key provided. Using default wandb login...


[34m[1mwandb[0m: Currently logged in as: [33mmzz20[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Successfully authenticated with wandb!


2D/3D training specified, datasets will be randomly mini-batched per epoch
2D/3D dataset and training -> data will be fully preloaded into RAM
Loading LR data from: ./Dataset/Bentheimer_mixed_fw90/Train/LR.npy
Loading HR data from: ./Dataset/Bentheimer_mixed_fw90/Train/HR.npy
LR shape: (250, 250, 250), HR shape: (1000, 1000, 1000)
Reading dataset, block size this epoch: 103 x 75 x 75 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 1.000000e-04


Epoch 1/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0731, GSRyzL: 0.0938


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 29.63341407775879, PSNR-SRC: 27.83168716430664, L1-Loss: 0.05053101554512977

New best model at epoch 1!
PSNR-XY: 29.6334, PSNR-XYZ: 27.8317, L1 Loss: 0.0505
Reading dataset, block size this epoch: 110 x 73 x 73 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.990232e-05


Epoch 2/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0463, GSRyzL: 0.0504


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 31.902999114990234, PSNR-SRC: 30.073437881469726, L1-Loss: 0.03856193497776985

New best model at epoch 2!
PSNR-XY: 31.9030, PSNR-XYZ: 30.0734, L1 Loss: 0.0386
Reading dataset, block size this epoch: 72 x 90 x 90 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.960968e-05


Epoch 3/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0405, GSRyzL: 0.0445


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 32.447698974609374, PSNR-SRC: 30.733896255493164, L1-Loss: 0.03596896827220917

New best model at epoch 3!
PSNR-XY: 32.4477, PSNR-XYZ: 30.7339, L1 Loss: 0.0360
Reading dataset, block size this epoch: 78 x 86 x 86 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.912322e-05


Epoch 4/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0384, GSRyzL: 0.0416


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 32.099238586425784, PSNR-SRC: 31.059212112426756, L1-Loss: 0.03483603671193123

New best model at epoch 4!
PSNR-XY: 32.0992, PSNR-XYZ: 31.0592, L1 Loss: 0.0348
Reading dataset, block size this epoch: 174 x 58 x 58 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.844487e-05


Epoch 5/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0370, GSRyzL: 0.0387


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 33.36779098510742, PSNR-SRC: 32.40069923400879, L1-Loss: 0.03120445869863033

New best model at epoch 5!
PSNR-XY: 33.3678, PSNR-XYZ: 32.4007, L1 Loss: 0.0312
Reading dataset, block size this epoch: 160 x 60 x 60 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.757730e-05


Epoch 6/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0352, GSRyzL: 0.0367


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 33.384098052978516, PSNR-SRC: 32.49186897277832, L1-Loss: 0.030090263485908507

New best model at epoch 6!
PSNR-XY: 33.3841, PSNR-XYZ: 32.4919, L1 Loss: 0.0301
Reading dataset, block size this epoch: 107 x 74 x 74 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.652394e-05


Epoch 7/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0327, GSRyzL: 0.0345


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.13421096801758, PSNR-SRC: 32.694484329223634, L1-Loss: 0.031065214797854422

New best model at epoch 7!
PSNR-XY: 34.1342, PSNR-XYZ: 32.6945, L1 Loss: 0.0311
Reading dataset, block size this epoch: 31 x 137 x 137 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.528894e-05


Epoch 8/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0328, GSRyzL: 0.0353


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.13420104980469, PSNR-SRC: 33.17862777709961, L1-Loss: 0.02962447181344032

New best model at epoch 8!
PSNR-XY: 34.1342, PSNR-XYZ: 33.1786, L1 Loss: 0.0296
Reading dataset, block size this epoch: 177 x 57 x 57 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.387718e-05


Epoch 9/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0313, GSRyzL: 0.0326


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.125144958496094, PSNR-SRC: 33.439842987060544, L1-Loss: 0.02740780673921108

New best model at epoch 9!
PSNR-XY: 34.1251, PSNR-XYZ: 33.4398, L1 Loss: 0.0274
Reading dataset, block size this epoch: 90 x 80 x 80 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.229423e-05


Epoch 10/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0295, GSRyzL: 0.0303


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.06901473999024, PSNR-SRC: 33.48613510131836, L1-Loss: 0.027213162183761595

New best model at epoch 10!
PSNR-XY: 34.0690, PSNR-XYZ: 33.4861, L1 Loss: 0.0272
Saving network weights (archive)
Saving network weights (rewritable checkpoint)
Saving model (rewritable checkpoint)
Reading dataset, block size this epoch: 66 x 94 x 94 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 9.054634e-05


Epoch 11/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0299, GSRyzL: 0.0309


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 33.183695220947264, PSNR-SRC: 32.38807563781738, L1-Loss: 0.03517596200108528
Reading dataset, block size this epoch: 44 x 115 x 115 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 8.864041e-05


Epoch 12/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0298, GSRyzL: 0.0306


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.45735244750976, PSNR-SRC: 34.03599243164062, L1-Loss: 0.027611548081040382

New best model at epoch 12!
PSNR-XY: 34.4574, PSNR-XYZ: 34.0360, L1 Loss: 0.0276
Reading dataset, block size this epoch: 150 x 62 x 62 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 8.658395e-05


Epoch 13/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0301, GSRyzL: 0.0309


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.42474365234375, PSNR-SRC: 34.074398803710935, L1-Loss: 0.02640445567667484

New best model at epoch 13!
PSNR-XY: 34.4247, PSNR-XYZ: 34.0744, L1 Loss: 0.0264
Reading dataset, block size this epoch: 28 x 145 x 145 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 8.438508e-05


Epoch 14/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0277, GSRyzL: 0.0283


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.88988189697265, PSNR-SRC: 34.20419616699219, L1-Loss: 0.02734937407076359

New best model at epoch 14!
PSNR-XY: 34.8899, PSNR-XYZ: 34.2042, L1 Loss: 0.0273
Reading dataset, block size this epoch: 134 x 66 x 66 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 8.205249e-05


Epoch 15/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0284, GSRyzL: 0.0284


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.72739486694336, PSNR-SRC: 34.50938339233399, L1-Loss: 0.026596492156386375

New best model at epoch 15!
PSNR-XY: 34.7274, PSNR-XYZ: 34.5094, L1 Loss: 0.0266
Reading dataset, block size this epoch: 44 x 115 x 115 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 7.959537e-05


Epoch 16/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0278, GSRyzL: 0.0279


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.641578674316406, PSNR-SRC: 34.33411483764648, L1-Loss: 0.027170846611261366
Reading dataset, block size this epoch: 109 x 73 x 73 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 7.702343e-05


Epoch 17/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0288, GSRyzL: 0.0293


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.064073944091795, PSNR-SRC: 33.96184158325195, L1-Loss: 0.027146874740719796
Reading dataset, block size this epoch: 114 x 71 x 71 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 7.434681e-05


Epoch 18/500:   0%|          | 0/200 [00:00<?, ?it/s]

Mean Epoch Performance: GSRxyL: 0.0275, GSRyzL: 0.0274


Validation:   0%|          | 0/5 [00:00<?, ?it/s]


Mean Validation PSNR-SR: 34.96498565673828, PSNR-SRC: 34.94762496948242, L1-Loss: 0.023847226053476334

New best model at epoch 18!
PSNR-XY: 34.9650, PSNR-XYZ: 34.9476, L1 Loss: 0.0238
Reading dataset, block size this epoch: 74 x 89 x 89 -> 4x


Creating Training Cubes:   0%|          | 0/200 [00:00<?, ?it/s]

Learning Rate: 7.157607e-05


Epoch 19/500:   0%|          | 0/200 [00:00<?, ?it/s]