In [2]:
!pip install pydicom numpy scikit-image pillow scipy SimpleITK

Collecting pydicom
  Using cached pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Collecting SimpleITK
  Using cached simpleitk-2.5.3-cp311-abi3-win_amd64.whl.metadata (7.3 kB)
Using cached pydicom-3.0.1-py3-none-any.whl (2.4 MB)
Using cached simpleitk-2.5.3-cp311-abi3-win_amd64.whl (18.8 MB)
Installing collected packages: SimpleITK, pydicom

   ---------------------------------------- 0/2 [SimpleITK]
   ---------------------------------------- 0/2 [SimpleITK]
   ---------------------------------------- 0/2 [SimpleITK]
   ---------------------------------------- 0/2 [SimpleITK]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -------------------- ------------------- 1/2 [pydicom]
   -----

In [1]:
import os  # For file/directory operations (e.g., listing files, creating dirs). What: System interaction module. Why: To handle batch processing of files in a directory. How useful: Automates working with multiple DICOM files without hardcoding paths.
import pydicom  # For reading DICOM files. What: Library for DICOM parsing. Why: DICOMs have metadata and pixels; this extracts them safely. How useful: Ensures accurate HU conversion and handles medical-specific formats.
import numpy as np  # For array math. What: Numerical computing library. Why: Images are arrays; needed for clipping, scaling. How useful: Fast, vectorized operations prevent slow loops, ideal for large images.
import SimpleITK as sitk  # For advanced image processing (not used here, but potential for extensions like registration). What: ITK wrapper. Why: Future-proofing; skipped in this basic script. How useful: For complex tasks like 3D handling if expanded.
from PIL import Image  # For image saving. What: Pillow's Image class. Why: Converts arrays to PNG for output. How useful: Enables easy visualization/storage in web-friendly formats.
from skimage import filters, morphology  # Filters (e.g., Otsu, median), morphology (e.g., disk, opening). What: Image processing submodules. Why: For thresholding, denoising, structuring elements. How useful: Provides ready-made algorithms for noise reduction and shape manipulation without custom code.

In [2]:
def save_as_png(array, path):
    """Utility to save a numpy array as a normalized PNG."""
    # What: Function to normalize and save array as PNG, handling binary or intensity data.
    # Why: Arrays need scaling to 0-255 for image formats; handles different types.
    # How useful: Ensures consistent visualization; binary masks become visible (white/black).
    
    # Handle both binary masks and intensity images
    # What: Checks dtype and normalizes accordingly.
    # Why: Masks (bool) need conversion; intensities may vary in range.
    # How useful: Versatile for all steps; prevents save errors on non-uint8 data.
    if array.dtype == bool:
        img_array = (array * 255).astype(np.uint8)
    else:
        # Standardize range to 0-255 for visualization
        # What: Min-max normalization to [0,255].
        # Why: Raw HU/values may be wide; this fits PNG range.
        # How useful: Makes images displayable; handles uniform arrays gracefully.
        amin = array.min()
        amax = array.max()
        if amax > amin:
            img_array = ((array - amin) / (amax - amin) * 255).astype(np.uint8)
        else:
            img_array = np.zeros_like(array, dtype=np.uint8)
    
    img = Image.fromarray(img_array)  # Converts to PIL Image.
    img.save(path)  # Saves to path.

In [3]:
def process_all_steps():
    # What: Batch function to process DICOMs, saving each preprocessing/segmentation step as PNG in subfolders.
    # Why: Visualizes pipeline progression; organizes outputs.
    # How useful: Easy inspection of intermediates; aids debugging/understanding effects.
    
    input_dir = "E:/Normal/DICOM"  # Input path; adjust to your DICOM folder (e.g., 'D:/CT_Dataset/Normal').
    base_output_dir = "E:/Normal/preprocessing_steps"  # Base output; subfolders per step.
    
    # Define the steps we want to save
    # What: List of step names for subfolders.
    # Why: Structures outputs (e.g., '1_Raw', '6_Final_Stripped').
    # How useful: Clear organization; matches your described sequence.
    steps = [
        "1_Raw",
        "2_HU_Normalized",
        "3_Windowed",
        "4_Denoised",
        "5_Skull_Mask",
        "6_Final_Stripped"
    ]
    
    for step in steps:
        os.makedirs(os.path.join(base_output_dir, step), exist_ok=True)  # Creates subfolders.
    
    files = [f for f in os.listdir(input_dir) if f.endswith('.dcm')]  # Lists DICOMs.
    print(f"Processing {len(files)} files into individual steps...")
    
    for f in files:
        try:
            file_path = os.path.join(input_dir, f)
            filename_base = f.replace('.dcm', '.png')
            
            # Step 1: Raw
            # What: Loads and saves original pixels.
            # Why: Baseline for comparison.
            # How useful: Shows unprocessed state; highlights later improvements.
            ds = pydicom.dcmread(file_path)
            raw_pixel = ds.pixel_array
            save_as_png(raw_pixel, os.path.join(base_output_dir, "1_Raw", filename_base))
            
            # Step 2: HU Normalized
            # What: Applies slope/intercept for HU conversion.
            # Why: Standardizes densities (e.g., air=-1000 HU).
            # How useful: Quantitative; enables tissue-specific analysis.
            slope = float(ds.RescaleSlope) if 'RescaleSlope' in ds else 1
            intercept = float(ds.RescaleIntercept) if 'RescaleIntercept' in ds else 0
            hu_image = raw_pixel * slope + intercept
            save_as_png(hu_image, os.path.join(base_output_dir, "2_HU_Normalized", filename_base))
            
            # Step 3: Windowing (Brain Window)
            # What: Clips HU to brain range (0-80 HU).
            # Why: Enhances soft tissue contrast.
            # How useful: Suppresses bone; focuses on brain for better visuals.
            window_center = 40
            window_width = 80
            min_val = window_center - (window_width / 2)
            max_val = window_center + (window_width / 2)
            windowed = np.clip(hu_image, min_val, max_val)
            save_as_png(windowed, os.path.join(base_output_dir, "3_Windowed", filename_base))
            
            # Step 4: Denoised
            # What: Median filter on normalized windowed image.
            # Why: Reduces noise post-windowing.
            # How useful: Smoother image; improves segmentation accuracy.
            # Normalize to 0-255 first for standard filter behavior
            norm_for_filter = ((windowed - min_val) / (max_val - min_val) * 255).astype(np.uint8)
            denoised = filters.median(norm_for_filter, morphology.disk(1))
            save_as_png(denoised, os.path.join(base_output_dir, "4_Denoised", filename_base))
            
            # Step 5: Skull Mask
            # What: Binary mask from thresholding/morphology.
            # Why: Identifies brain region.
            # How useful: Visualizes isolation; precursor to stripping.
            thresh = filters.threshold_otsu(denoised)
            binary = denoised > thresh
            from skimage.segmentation import clear_border
            bw = clear_border(binary)
            selem = morphology.disk(2)
            opened = morphology.opening(bw, selem)
            
            from skimage.measure import label, regionprops
            label_img = label(opened)
            props = regionprops(label_img)
            if props:
                largest_region = max(props, key=lambda x: x.area)
                mask = label_img == largest_region.label
                from scipy import ndimage as ndi
                mask = ndi.binary_fill_holes(mask)
                mask = morphology.dilation(mask, morphology.disk(3))
            else:
                mask = np.ones_like(denoised, dtype=bool)
                
            save_as_png(mask, os.path.join(base_output_dir, "5_Skull_Mask", filename_base))
            
            # Step 6: Final Stripped
            # What: Applies mask to denoised image.
            # Why: Removes skull for brain-only view.
            # How useful: Clean input for DL; focuses on intracranial features.
            final = denoised * mask
            save_as_png(final, os.path.join(base_output_dir, "6_Final_Stripped", filename_base))
            
        except Exception as e:
            print(f"Error processing {f}: {e}")

In [13]:
# Cell 1: Imports and Setup
import os  # For directory operations and file handling. What: Manages paths and creates folders. Why better: Built-in, no extra deps; more reliable than manual string manipulation for cross-OS compatibility.
from PIL import Image  # For image loading and resizing. What: Handles image I/O and transformations. Why better: Lightweight and efficient for PNG/JPG; faster than OpenCV for simple tasks, no unnecessary features.
import numpy as np  # For array operations if needed. What: Numerical computations. Why better: Vectorized ops are faster than loops; essential if extending to batch processing.

# Define directories
processed_dir = 'E:/Normal/preprocessing_steps/6_Final_Stripped'  # Input: Where preprocessed images are stored. What: Source path. Why better: Centralized; easy to change without code edits.
lr_dir = 'E:/Normal/train_lr'  # Output for low-res images. What: LR storage. Why better: Organized subdirs prevent clutter; follows standard SR dataset structure.
hr_dir = 'E:/Normal/train_hr'  # Output for high-res (original processed). What: HR storage. Why better: Pairs with LR for easy loading in training.

os.makedirs(lr_dir, exist_ok=True)  # Creates LR dir if missing. What: Ensures dir exists. Why better: 'exist_ok' avoids errors; safer than manual checks.
os.makedirs(hr_dir, exist_ok=True)  # Same for HR. What: Dir creation. Why better: Consistent handling; prevents runtime failures.

In [7]:
# Run the batch function.
process_all_steps()  # Executes on all files in input_dir, saving steps as PNGs.

Processing 4427 files into individual steps...


In [14]:
# Cell 2: Create LR-HR Pairs Function
def create_lr_hr_pairs(processed_dir, lr_dir, hr_dir, scale_factor=4):
    # What: Function to generate LR from HR by downsampling and save pairs.
    # Why better: Modular; reusable. Downsampling simulates real LR; better than random noise as it mimics acquisition degradation.
    
    files = [f for f in os.listdir(processed_dir) if f.endswith('.png')]  # Lists PNGs. What: Filters images. Why better: Specific extension avoids non-images; list comp is concise/Pythonic over loops.
    print(f"Found {len(files)} processed images.")  # Logs count. What: User feedback. Why better: Helps debug; simple print over complex logging for this script.
    
    for f in files:  # Loops over files. What: Processes each. Why better: Sequential for simplicity; parallelizable if needed, but overkill for small datasets.
        hr_path = os.path.join(processed_dir, f)  # Full HR path. What: Constructs path. Why better: os.path.join handles OS differences; safer than '/' concatenation.
        hr_img = Image.open(hr_path)  # Loads HR image. What: Opens as PIL Image. Why better: PIL preserves quality; easier than numpy for resizing.
        
        # Create LR by downsampling
        lr_size = (hr_img.width // scale_factor, hr_img.height // scale_factor)  # Computes LR dimensions. What: Integer division for downscale. Why better: Ensures integer sizes; avoids float errors in resizing.
        lr_img = hr_img.resize(lr_size, Image.BICUBIC)  # Resizes with bicubic. What: Downsamples. Why better: Bicubic preserves details better than bilinear/nearest; standard for SR simulation.
        
        # Save LR and HR
        lr_save = os.path.join(lr_dir, f)  # LR save path. What: Same filename for pairing. Why better: Easy matching; no renaming needed.
        hr_save = os.path.join(hr_dir, f)  # HR save path. What: Copies original. Why better: Keeps HR intact; allows direct comparison.
        lr_img.save(lr_save)  # Saves LR. What: PNG export. Why better: Lossless; maintains quality over JPG compression.
        hr_img.save(hr_save)  # Saves HR. What: Copies to HR dir. Why better: Centralizes pairs; avoids reading from processed_dir later.
        
        print(f"Created pair for {f}")  # Logs progress. What: Feedback per file. Why better: Tracks completion; useful for large sets.

In [15]:
# Cell 3: Run the Function
create_lr_hr_pairs(processed_dir, lr_dir, hr_dir, scale_factor=4)  # Executes with 4x scale. What: Calls function. Why better: 4x is common in SR; adjustable param for flexibility.
# After running, check lr_dir and hr_dir for pairs. Each LR is downsampled version of HR; used for training super-resolution models.

Found 4427 processed images.
Created pair for 10000.png
Created pair for 10001.png
Created pair for 10004.png
Created pair for 10005.png
Created pair for 10006.png
Created pair for 10008.png
Created pair for 10009.png
Created pair for 10010.png
Created pair for 10011.png
Created pair for 10012.png
Created pair for 10013.png
Created pair for 10014.png
Created pair for 10015.png
Created pair for 10016.png
Created pair for 10018.png
Created pair for 10019.png
Created pair for 10021.png
Created pair for 10022.png
Created pair for 10023.png
Created pair for 10025.png
Created pair for 10026.png
Created pair for 10027.png
Created pair for 10030.png
Created pair for 10032.png
Created pair for 10034.png
Created pair for 10035.png
Created pair for 10037.png
Created pair for 10040.png
Created pair for 10041.png
Created pair for 10042.png
Created pair for 10043.png
Created pair for 10044.png
Created pair for 10048.png
Created pair for 10051.png
Created pair for 10054.png
Created pair for 10055.png

In [4]:
# Cell 1: Imports for Model Building
!pip install torch torchvision torchaudio
import torch  # For tensor operations and neural networks. What: Core PyTorch library. Why: Builds and trains models. Better: GPU acceleration; flexible over TF for research prototypes.
import torch.nn as nn  # For network modules (Conv2d, etc.). What: NN building blocks. Why: Defines layers. Better: Modular; easy to customize architectures like LapSRN/DRRN.
import torch.nn.functional as F  # For functional ops (ReLU). What: Activation/ops without state. Why: Used in forward passes. Better: Stateless; memory-efficient in recursives.



In [20]:
# Cell 3: DRRN Model Definition
class DRRN(nn.Module):
    # What: Deep Recursive Residual Network class.
    # Why: For post-upsampling refinement with deep, efficient structure.
    # How better: Recursive units with sharing (low params ~297K for 52 layers) vs. non-recursive deep nets (e.g., VDSR with more params); multi-path (2^U) captures complex textures better than single-path.
    
    def __init__(self, num_blocks=1, num_units=25, num_filters=128):
        # What: Constructor for B1U25 config.
        # Why: Sets up recursive structure with global/local residuals.
        # How better: Num_units=25 gives depth without explosion; better than shallower (e.g., SRCNN) for fine details in CT refinement.
        super().__init__()  # What: Parent init.
        # Why: Required for Module.
        # How better: Inherits PyTorch utils; standard.
        self.initial_conv = nn.Conv2d(1, num_filters, kernel_size=3, padding=1)  # What: Input to feature space.
        # Why: Expands channels for processing.
        # How better: 128 filters balance capacity/efficiency; better than 64 (more expressive for refinement).
        self.recursive_unit = nn.Sequential(  # What: Shared unit module.
            # Why: Reused in loop for weight sharing.
            # How better: Sharing keeps params low (~297K total); better than unshared (overparametrized, harder to train).
            nn.ReLU(),  # What: Pre-activation ReLU.
            # Why: Non-linearity before conv.
            # How better: Pre-act stabilizes gradients in deep recursives; better than post-act (smoother optimization).
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1, bias=False),  # What: First conv in unit.
            # Why: Feature transform without bias (common in residuals).
            # How better: No bias reduces params slightly; padding keeps size, better than strided.
            nn.ReLU(),  # What: Mid-unit activation.
            # Why: Between convs for complexity.
            # How better: Allows non-linear combos; ReLU fast/simple over Leaky (per paper).
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1, bias=False)  # What: Second conv.
            # Why: Pair completes unit.
            # How better: Two convs per unit > one (depth); no bias for residual norm.
        )
        self.final_conv = nn.Conv2d(num_filters, 1, kernel_size=3, padding=1)  # What: Compresses to residual output.
        # Why: Predicts refinement map.
        # How better: 3x3 local; reduces channels efficiently.
        self.num_units = num_units  # What: Stores loop count.
        # Why: For forward iteration.
        # How better: Configurable; easy scaling vs. hardcoded.
    
    def forward(self, x):
        # What: Forward with recursive local/global residuals.
        # Why: Refines input via multi-path learning.
        # How better: 2^25 paths capture diverse features with few params; better than non-residual (vanishing gradients in 52 layers).
        features = self.initial_conv(x)  # What: Base features.
        # Why: Prepares for recursion.
        # How better: Channel expansion; better start than raw input.
        for _ in range(self.num_units):  # What: Loops units (25x).
            # Why: Builds effective depth (52 layers).
            # How better: Loop with sharing > explicit layers (code/param efficient).
            res = self.recursive_unit(features)  # What: Computes unit output.
            # Why: Local residual block.
            # How better: Pre-act structure > standard ResNet (better for very deep).
            features = features + res  # What: Local add.
            # Why: Skip connection per unit.
            # How better: Accumulates refinements; stabilizes gradients.
        res = self.final_conv(features)  # What: Final residual predict.
        # Why: Maps back to 1ch.
        # How better: Post-recursion; leverages all depth.
        return x + res  # What: Global residual add.
        # Why: Output = input + learned residual.
        # How better: Preserves high-freq from input; better for SR (avoids over-smoothing).

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
from PIL import Image  # For loading images; replace if using other formats

# Assume LapSRN class from Cell 1, but with fix to global_up padding
class LapSRN(nn.Module):
    # What: Defines the Laplacian Pyramid Super-Resolution Network class, inheriting from nn.Module.
    # Why: To create a model for progressive image upsampling using a pyramid structure.
    # How better: Unlike single-stage SR models (e.g., SRCNN), LapSRN's pyramid allows multi-scale predictions in one pass, reducing artifacts and improving efficiency by building incrementally; better than GANs for stability without adversarial training.
    def __init__(self, scale_factor=4, num_recursive_blocks=5, num_layers_per_block=5):
        # What: Constructor to initialize model parameters and layers.
        # Why: Sets up configurable hyperparameters like scale and block depths for flexibility.
        # How better: Parameterized init allows easy tuning (e.g., D5R5 config); better than fixed architectures as it adapts to different scales without redesign.
        super().__init__() # What: Calls parent class constructor.
        # Why: Required for nn.Module subclasses to initialize properly.
        # How better: Ensures inheritance of PyTorch functionalities like .to(device); standard and error-free over manual init.
        self.scale_factor = scale_factor # What: Stores the total upsampling factor (e.g., 4 for 4x SR).
        # Why: To determine number of pyramid levels.
        # How better: Explicit storage allows dynamic computation; better than hardcoding for multi-scale support.
        self.pyramid_levels = int(np.log2(scale_factor)) # What: Calculates pyramid levels using log base 2 (e.g., 2 for 4x).
        # Why: Each level doubles resolution progressively.
        # How better: Math-based derivation ensures correctness for power-of-2 scales; more efficient than arbitrary levels, reducing computation vs. non-pyramid nets.
        self.initial_conv = nn.Conv2d(1, 64, kernel_size=3, padding=1) # What: Initial convolutional layer from 1 channel (grayscale CT) to 64 features.
        # Why: Extracts base features from input.
        # How better: 3x3 kernel is local/efficient; padding='same' preserves spatial dims, better than larger kernels (less params) or no padding (size loss).
       
        # What: Creates a list of recursive blocks using ModuleList.
        # Why: For feature embedding with multiple blocks of conv layers.
        # How better: ModuleList allows PyTorch to track params; loops create independent layers, preventing unintended weight sharing issues in comprehensions; better than flat stack (modular for debugging).
        self.recursive_blocks = nn.ModuleList()
        for _ in range(num_recursive_blocks): # What: Outer loop for num_recursive_blocks (5).
            # Why: Builds each block separately.
            # How better: Ensures fresh instances per block; better than list comp with * (avoids syntax errors and sharing).
            layers = [] # What: Temp list for layers in one block.
            # Why: To build sequential per block.
            # How better: Accumulates layers dynamically; flexible for varying depths.
            for _ in range(num_layers_per_block): # What: Inner loop for num_layers_per_block (5).
                # Why: Adds conv + activation pairs per block.
                # How better: Repeated convs build depth efficiently; better than shallow nets (captures complex features with fewer params via recursion).
                layers.append(nn.Conv2d(64, 64, kernel_size=3, padding=1)) # What: Adds 3x3 conv layer.
                # Why: Feature transformation.
                # How better: Consistent 64 channels keep dims; padding maintains size, better than strided (no downsampling needed here).
                layers.append(nn.LeakyReLU(0.2)) # What: Adds LeakyReLU activation.
                # Why: Introduces non-linearity.
                # How better: alpha=0.2 allows small negatives, mitigating dying ReLU; better than ReLU (avoids zero gradients) or sigmoid (no vanishing).
            self.recursive_blocks.append(nn.Sequential(*layers)) # What: Appends Sequential of layers to list.
            # Why: Groups block for easy forward call.
            # How better: Sequential simplifies forward; * unpacks list, clean syntax over manual add_module.
       
        self.up_feature = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1) # What: Transpose conv for feature upsampling.
        # Why: Learned 2x upscale in feature branch.
        # How better: Kernel=4/stride=2 is standard for artifact-free up; better than PixelShuffle (simpler) or interp (learned filters capture patterns).
        self.residual = nn.Conv2d(64, 1, kernel_size=3, padding=1) # What: Conv to predict residual (1 channel).
        # Why: Sub-band residual for detail addition.
        # How better: Reduces to 1ch efficiently; better than direct predict (focuses on high-freq, pyramid advantage).
        self.up_image = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1) # What: Transpose conv for image branch upsampling.
        # Why: Upscales current reconstruction.
        # How better: Parallel to features; kernel/stride same as up_feature for alignment, better than shared layer (branch-specific).
        self.global_up = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=scale_factor, padding=0) # FIXED: Changed padding=1 to 0 to ensure output size = scale_factor * input_size (e.g., 4*128=512, not 510).
        # Why: Direct global upsampling from input.
        # How better: Stride=scale_factor (e.g., 4) for end-to-end skip; better than no global (stabilizes training, reduces error propagation).
    def forward(self, x):
        # What: Forward pass computing multi-scale outputs.
        # Why: Generates predictions at each level + final.
        # How better: Progressive upscale reduces hallucination; better than single-pass large deconv (less checkerboard artifacts).
        features = self.initial_conv(x) # What: Applies initial conv.
        # Why: Base feature extraction.
        # How better: Starts with rich reps; better than direct input use (learns low-level edges).
        for block in self.recursive_blocks: # What: Loops over blocks.
            # Why: Refines features recursively.
            # How better: Iterative refinement; better than one-pass deep (residuals prevent degradation).
            res = features # What: Stores residual.
            # Why: For local skip connection.
            # How better: Identity mapping; easier optimization than non-residual.
            features = block(features) # What: Applies block.
            # Why: Transforms features.
            # How better: Sequential call; clean over manual layer-by-layer.
            features = res + features # What: Adds residual.
            # Why: Local residual learning.
            # How better: Stabilizes gradients; better than plain add (prevents vanishing in depth).
       
        predictions = [] # What: List to store level outputs.
        # Why: For multi-scale supervision/loss.
        # How better: Enables deep supervision; improves training vs. final loss only.
        current = x # What: Initializes image branch with input.
        # Why: For progressive reconstruction.
        # How better: Starts from LR; builds up, better than all from features (preserves original signal).
        for level in range(self.pyramid_levels): # What: Loop over levels (e.g., 2).
            # Why: Handles each 2x step.
            # How better: Divide upscale; reduces complexity per step vs. one big upscale.
            up_f = self.up_feature(features) # What: Upsamples features.
            # Why: Prepares for level residual.
            # How better: Learned up; captures context better than fixed.
            res = self.residual(up_f) # What: Predicts residual.
            # Why: Adds details at this scale.
            # How better: Focuses on diffs; efficient for SR.
            up_i = self.up_image(current) # What: Upsamples current recon.
            # Why: Coarse up for addition.
            # How better: Separate branch; optimizes for image vs. features.
            recon = up_i + res # What: Reconstructs level.
            # Why: Combines coarse + detail.
            # How better: Residual recon; preserves info better than overwrite.
            predictions.append(recon) # What: Stores level pred.
            # Why: For supervision.
            # How better: Multi-loss; improves intermediates.
            current = recon # What: Updates for next level.
            # Why: Progressive build.
            # How better: Chains levels; accumulates accuracy.
            features = up_f # What: Updates features.
            # Why: Passes upsampled to next.
            # How better: Hierarchical; reuses refined feats.
        global_u = self.global_up(x) # What: Direct upscale from input.
        # Why: Global residual.
        # How better: Long skip; stabilizes end-to-end.
        output = global_u + predictions[-1] # What: Final output.
        # Why: Enhances with global.
        # How better: Combines paths; better than local only.
        return predictions + [output] # What: Returns all preds + final.
        # Why: For flexible loss (sum weighted).
        # How better: Enables training on intermediates; superior performance.

In [8]:
# Example Custom Dataset (adjust paths and loading as per your data)
class CustomSRDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None, crop_size=None):
        # What: Dataset for paired LR/HR images.
        # Why: Loads CT slices for SR training.
        # How better: Handles variable sizes; optional crop for fixed patches.
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_files = sorted(os.listdir(lr_dir))
        self.transform = transform
        self.crop_size = crop_size  # Optional: (lr_crop_h, lr_crop_w) for random crops
    
    def __len__(self):
        return len(self.lr_files)
    
    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.lr_files[idx])  # Assume same filenames
        
        lr = Image.open(lr_path).convert('L')  # Grayscale
        hr = Image.open(hr_path).convert('L')
        
        if self.transform:
            lr = self.transform(lr)
            hr = self.transform(hr)
        
        # Optional random crop for fixed size (better for SR to avoid padding)
        if self.crop_size:
            crop_h, crop_w = self.crop_size
            max_i = lr.shape[1] - crop_h
            max_j = lr.shape[2] - crop_w
            i = torch.randint(0, max_i + 1, (1,)).item() if max_i >= 0 else 0
            j = torch.randint(0, max_j + 1, (1,)).item() if max_j >= 0 else 0
            lr = lr[:, i:i+crop_h, j:j+crop_w]
            # HR crop: scale_factor=4, so 4x larger crop
            hr_crop_h = crop_h * 4
            hr_crop_w = crop_w * 4
            hr_i = i * 4
            hr_j = j * 4
            hr = hr[:, hr_i:hr_i+hr_crop_h, hr_j:hr_j+hr_crop_w]
        
        return lr, hr

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
from PIL import Image  # For loading images; replace if using other formats

class CharbonnierLoss(nn.Module):
    # What: Defines Charbonnier loss for robust regression.
    # Why: Approximates L1 but differentiable; better for outliers in SR than MSE.
    # How better: epsilon prevents zero gradient issues; more stable than pure L1.
    def __init__(self, epsilon=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.epsilon = epsilon
   
    def forward(self, pred, target):
        # What: Computes mean Charbonnier over batch.
        # Why: Pixel-wise robust diff.
        # How better: Mean reduces to scalar; works on variable sizes if shapes match.
        diff = pred - target
        loss = torch.mean(torch.sqrt(diff * diff + self.epsilon * self.epsilon))
        return loss

def pad_collate(batch):
    # What: Custom collate to pad variable-height images (widths assumed same).
    # Why: Fixes RuntimeError by making tensors stackable.
    # How better: Centers padding; preserves content vs. resizing (avoids distortion); zeros pad assumes black background.
    lrs, hrs = list(zip(*batch))
    
    # Pad LRs
    max_h_lr = max(lr.shape[1] for lr in lrs)  # shape[1]=H for [C,H,W]
    lrs_padded = []
    for lr in lrs:
        pad_top = (max_h_lr - lr.shape[1]) // 2
        pad_bottom = max_h_lr - lr.shape[1] - pad_top
        lrs_padded.append(F.pad(lr, (0, 0, pad_top, pad_bottom)))  # pad (left, right, top, bottom)
    
    # Pad HRs similarly
    max_h_hr = max(hr.shape[1] for hr in hrs)
    hrs_padded = []
    for hr in hrs:
        pad_top = (max_h_hr - hr.shape[1]) // 2
        pad_bottom = max_h_hr - hr.shape[1] - pad_top
        hrs_padded.append(F.pad(hr, (0, 0, pad_top, pad_bottom)))
    
    return torch.stack(lrs_padded), torch.stack(hrs_padded)


# Define transforms (normalize if needed, e.g., CT Hounsfield units)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=0.5, std=0.5)  # Uncomment if needed
])

# Create Dataset and DataLoader with collate_fn
train_dataset = CustomSRDataset(lr_dir='E:/Normal/train_lr', hr_dir='E:/Normal/train_hr', transform=transform, crop_size=(32, 128))  # Example crop: LR 32x128, HR 128x512; adjust
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=pad_collate)  # Use pad_collate; remove if using crop_size to make all fixed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lapsrn = LapSRN(scale_factor=4).to(device)

# Parameter check (should print ~1,056,194 as you saw)
print(f"Model has {sum(p.numel() for p in lapsrn.parameters())} parameters")

optimizer = torch.optim.Adam(lapsrn.parameters(), lr=1e-4)
criterion = CharbonnierLoss()
weights = [0.5, 1.0]  # For two intermediate levels (2x and 4x)
num_epochs = 5

for epoch in range(num_epochs):
    lapsrn.train()
    total_loss = 0
    for lr, hr in train_loader:
        lr, hr = lr.to(device), hr.to(device)
        predictions = lapsrn(lr)
       
        # Multi-scale loss
        loss = 0
        current_scale = 2  # Start at 2 for first level (2x)
        for i, pred in enumerate(predictions[:-1]):  # Intermediate levels
            target_size = (lr.shape[2] * current_scale, lr.shape[3] * current_scale)  # H, W
            target_resized = F.interpolate(
                hr,
                size=target_size,
                mode='bicubic',  # Better for images
                align_corners=False
            )
            loss += weights[i] * criterion(pred, target_resized)
            current_scale *= 2
       
        # Final output loss
        loss += criterion(predictions[-1], hr)
       
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
   
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Save model
torch.save(lapsrn.state_dict(), 'E:/Normal/models/lapsrn.pth')

Model has 990051 parameters
Epoch [1/5], Loss: 0.1962


FileNotFoundError: [Errno 2] No such file or directory: 'E:/Normal/train_hr\\10130.png'