In [2]:
import os
import sys
import traceback
from pathlib import Path

import numpy as np
from PIL import Image
import pydicom

# =========================================================
# 1. SETTINGS
# =========================================================

# Where all the patient folders are living right now
DATA_FOLDER = Path("/blue/eel6935/anany.sharma/Final_Project_EEL6935/ProstateMRI_T2")

# The name of the new folder we will make for the pictures
SAVE_FOLDER = "png_conv"

# We ignore scans that have fewer than 3 slices (too small)
SMALLEST_ALLOWED = 3

# =========================================================
# 2. HELPER TOOLS
# =========================================================

def get_patient_folders(main_folder: Path):
    """
    Finds all the sub-folders inside the main folder.
    """
    # Look for items that are actual directories
    folders = [x for x in main_folder.iterdir() if x.is_dir()]
    folders.sort()
    return folders


def make_image_bright_and_clear(image_3d: np.ndarray) -> np.ndarray:
    """
    Fixes the brightness of the 3D scan so it looks good as a picture.
    It stretches the pixel values to fit between 0 and 255.
    """
    # Convert to decimal numbers for math
    data = image_3d.astype(np.float32)
    
    # Find the very dark (1%) and very bright (99%) dots
    low_val, high_val = np.percentile(data, [1, 99])

    # Safety check: if the math breaks, just use the simple min/max
    if high_val <= low_val:
        min_v, max_v = data.min(), data.max()
        if max_v <= min_v:
            # If the image is empty, return a black image
            return np.zeros_like(data, dtype=np.uint8)
        
        # Stretch min to 0 and max to 1
        data = (data - min_v) / (max_v - min_v)
        # Multiply by 255 to make it a standard picture number
        data = (data * 255.0).clip(0, 255).astype(np.uint8)
        return data

    # Clip the super bright outliers so they don't ruin the contrast
    data = np.clip(data, low_val, high_val)
    # Stretch to 0-1 range
    data = (data - low_val) / (high_val - low_val + 1e-8)
    # Convert to 0-255 range
    data = (data * 255.0).clip(0, 255).astype(np.uint8)
    return data


# [Image of MRI anatomical planes]

def rate_this_scan(header_info, file_count: int) -> float:
    """
    This function gives points to a scan to decide if it is the "best" one.
    We want 'T2' type scans that are cut horizontally ('Axial').
    """
    points = 0.0
    
    # Get the description text and make it lowercase
    desc_text = str(getattr(header_info, "SeriesDescription", "")).lower()
    machine_type = str(getattr(header_info, "Modality", "")).lower()

    # Give 5 points if it says T2 (the type of MRI contrast we want)
    if "t2" in desc_text:
        points += 5.0
    
    # Give 3 points if it is cut horizontally (Axial / Transverse)
    # This helps us ignore scans cut from the side or front.
    if "tra" in desc_text or "ax" in desc_text or "trans" in desc_text:
        points += 3.0
    
    # Give 1 point if it's definitely an MRI machine
    if machine_type == "mr":
        points += 1.0

    # Give a tiny bonus for having more slices (bigger is usually better)
    points += file_count / 100.0 

    return points


def read_dicom_files(file_list):
    """
    Reads a bunch of .dcm files and stacks them into a 3D block.
    """
    loaded_slices = []
    
    for current_file in file_list:
        try:
            # Read the medical file
            data = pydicom.dcmread(str(current_file))
            
            # Find the "page number" of this slice so we can sort them later
            page_num = int(getattr(data, "InstanceNumber", 0))

            # Get the raw picture dots
            pixels = data.pixel_array.astype(np.float32)

            # Apply any math corrections saved in the file header
            slope = float(getattr(data, "RescaleSlope", 1.0))
            intercept = float(getattr(data, "RescaleIntercept", 0.0))
            pixels = pixels * slope + intercept

            loaded_slices.append((page_num, pixels))
        except Exception as error:
            print(f"    ⚠️ Oops, could not read {current_file}: {error}")
            traceback.print_exc()

    if not loaded_slices:
        return None, None

    # Put the pages in order (Page 1, Page 2, Page 3...)
    loaded_slices.sort(key=lambda x: x[0])
    
    # Extract just the page numbers
    page_nums = [s[0] for s in loaded_slices]
    
    # Stack the pixel arrays on top of each other like pancakes
    # Result shape: (Height, Width) -> (Count, Height, Width)
    volume_block = np.stack([s[1] for s in loaded_slices], axis=0)
    
    return volume_block, page_nums


# =========================================================
# 3. MAIN PROGRAM
# =========================================================

def run_conversion():
    # 1. Find all patients
    all_patients = get_patient_folders(DATA_FOLDER)
    print(f"Found {len(all_patients)} patient folders in {DATA_FOLDER}")

    for patient_path in all_patients:
        print(f"\n=== Working on: {patient_path.name} ===")

        # Check if we already finished this patient
        already_done = list(patient_path.rglob("*.png"))
        if len(already_done) > 0:
            print(f"  ✅ Found {len(already_done)} PNGs already. Skipping.")
            continue

        # Find all .dcm files for this patient
        dicom_files = list(patient_path.rglob("*.dcm"))
        if not dicom_files:
            print("  ⚠️ No medical files found here. Skipping.")
            continue

        print(f"  Found {len(dicom_files)} raw files.")

        # 2. Sort files into groups (called Series)
        # We use a dictionary to group them by their unique ID
        groups = {} 
        for f in dicom_files:
            try:
                # Read just the header info (fast)
                header = pydicom.dcmread(str(f), stop_before_pixels=True)
                
                # Get the Series ID
                uid = getattr(header, "SeriesInstanceUID", None) or "unknown_id"
                
                if uid not in groups:
                    groups[uid] = {"files": [], "header": header}
                groups[uid]["files"].append(f)
            except Exception as e:
                print(f"    ⚠️ Header error on {f}: {e}")
                traceback.print_exc()

        if not groups:
            print("  ⚠️ No valid groups found. Skipping.")
            continue

        # 3. Pick the best group (The T2 Axial one)
        winner_id = None
        highest_score = -1000000.0
        
        for uid, info in groups.items():
            head = info["header"]
            files = info["files"]
            
            # Calculate score
            score = rate_this_scan(head, len(files))
            desc = str(getattr(head, "SeriesDescription", ""))
            
            print(f"  Group {uid} | {len(files)} imgs | Desc='{desc}' | Score={score:.2f}")
            
            # Update the winner if this one is better and big enough
            if len(files) >= SMALLEST_ALLOWED and score > highest_score:
                highest_score = score
                winner_id = uid

        if winner_id is None:
            print(f"  ⚠️ No group was good enough (needs {SMALLEST_ALLOWED}+ slices). Skipping.")
            continue

        winner_data = groups[winner_id]
        winner_files = winner_data["files"]
        winner_desc = str(getattr(winner_data["header"], "SeriesDescription", ""))
        
        print(f"  👉 Winner is: {winner_id} ({len(winner_files)} slices, '{winner_desc}')")

        # 4. Load the winner into memory
        volume_3d, _ = read_dicom_files(winner_files)
        if volume_3d is None:
            print("  ⚠️ Failed to load pixel data. Skipping.")
            continue

        print(f"  Raw size: {volume_3d.shape}")

        # 5. Make it look nice (0-255 brightness)
        clean_volume = make_image_bright_and_clear(volume_3d)
        print(f"  Clean size: {clean_volume.shape}, Type={clean_volume.dtype}")

        # 6. Save as PNG pictures
        target_folder = patient_path / SAVE_FOLDER
        target_folder.mkdir(parents=True, exist_ok=True)

        count = clean_volume.shape[0]
        for i in range(count):
            # Get one slice
            single_slice = clean_volume[i] 
            # Make a picture object
            pic = Image.fromarray(single_slice)
            # Save it
            save_name = target_folder / f"slice_{i:03d}.png"
            pic.save(save_name)

        print(f"  ✅ Saved {count} pictures to {target_folder}")

    print("\nAll done!")


if __name__ == "__main__":
    run_conversion()

Found 842 patient folders in /blue/eel6935/anany.sharma/Final_Project_EEL6935/ProstateMRI_T2

=== Working on: Prostate-MRI-US-Biopsy-0001 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0002 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0003 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0005 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0006 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0007 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0008 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0009 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0012 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biopsy-0013 ===
  ✅ Found 120 PNGs already. Skipping.

=== Working on: Prostate-MRI-US-Biop

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Where to save models
DATA_ROOT = Path("/blue/eel6935/anany.sharma/Final_Project_EEL6935/ProstateMRI_T2_MR_png")
MODEL_DIR = DATA_ROOT / "models_interp"
MODEL_DIR.mkdir(parents=True, exist_ok=True)


Using device: cuda


In [None]:
import os
import random
import math
from pathlib import Path
from collections import defaultdict
import time 

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, UnidentifiedImageError
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast

# Attempt to load the specialized medical image library
try:
    import pydicom
except ImportError as e:
    raise ImportError("The 'pydicom' library is required.") from e

# Determine the processing hardware
print("PyTorch Version:", torch.__version__)
processing_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp_dev_type = "cuda" if torch.cuda.is_available() else "cpu"
gpu_count = torch.cuda.device_count()
print("Using hardware:", processing_device, "| Available GPUs:", gpu_count)

USE_MULTI_GPU = (processing_device.type == "cuda" and gpu_count > 1)
if USE_MULTI_GPU:
    print(f"✅ Parallel Processing enabled across {gpu_count} GPUs")

# --- Utility Functions ---
def extract_model_weights(model_instance):
    return model_instance.module.state_dict() if isinstance(model_instance, nn.DataParallel) else model_instance.state_dict()

def inject_model_weights(model_instance, weight_dict):
    if isinstance(model_instance, nn.DataParallel):
        model_instance.module.load_state_dict(weight_dict)
    else:
        model_instance.load_state_dict(weight_dict)

# ============================================================
# 1. CONFIGURATION
# ============================================================

RAW_SCAN_PATH = Path("/blue/eel6935/anany.sharma/Final_Project_EEL6935/ProstateMRI_T2")
PROCESSED_PNG_PATH = Path("/blue/eel6935/anany.sharma/Final_Project_EEL6935/ProstateMRI_T2_MR_png")

TRAIN_SET_HR_DIR = PROCESSED_PNG_PATH / "training_data" / "HR"
VALIDATION_SET_HR_DIR = PROCESSED_PNG_PATH / "validation_data" / "HR"
TEST_SET_HR_DIR = PROCESSED_PNG_PATH / "testing_data" / "HR"

for d in [TRAIN_SET_HR_DIR, VALIDATION_SET_HR_DIR, TEST_SET_HR_DIR]:
    d.mkdir(parents=True, exist_ok=True)

TARGET_HEIGHT, TARGET_WIDTH = 256, 256
TRAIN_RATIO, VALIDATION_RATIO, TEST_RATIO = 0.7, 0.15, 0.15

BATCH_SIZE_PER_STEP = 8
DATA_WORKER_THREADS = 4 
TOTAL_TRAINING_EPOCHS = 20 
LEARNING_RATE = 1e-4 

random.seed(42)

# ============================================================
# 2. DATA PREPARATION (Robust)
# ============================================================

def SafeImageCheck(path):
    """
    Tries to open an image to verify it is not corrupt.
    Returns True if safe, False if corrupt.
    """
    try:
        with Image.open(path) as img:
            img.verify() 
        return True
    except (OSError, UnidentifiedImageError):
        print(f"⚠️ Found corrupt image: {path.name}. Skipping associated triplets.")
        return False

def read_patient_scan_volume(patient_directory, target_dim=(256, 256)):
    """Reads DICOM files, normalizes, and stacks them into a volume."""
    dicom_files = sorted(patient_directory.rglob("*.dcm"))
    if len(dicom_files) == 0: return None

    slice_data_list = []
    for file_path in dicom_files:
        try:
            ds = pydicom.dcmread(file_path, stop_before_pixels=False)
        except Exception: continue
        
        if not hasattr(ds, "pixel_array"): continue
        
        raw_image = ds.pixel_array.astype(np.float32)
        rescale_slope = float(getattr(ds, "RescaleSlope", 1.0))
        rescale_intercept = float(getattr(ds, "RescaleIntercept", 0.0))
        processed_image = raw_image * rescale_slope + rescale_intercept
        instance_number = int(getattr(ds, "InstanceNumber", len(slice_data_list)))
        slice_data_list.append((instance_number, processed_image))

    if len(slice_data_list) < 3: return None

    slice_data_list.sort(key=lambda x: x[0])
    
    resized_slices_list = []
    for inst, img in slice_data_list:
        img_min = img.min()
        img_max = img.max()
        if img_max > img_min:
            img_norm = (img - img_min) / (img_max - img_min)
        else:
            img_norm = np.zeros_like(img)
        
        img_uint8 = (img_norm * 255).astype(np.uint8)
        pil_image = Image.fromarray(img_uint8)
        resized_pil_image = pil_image.resize((target_dim[1], target_dim[0]), resample=Image.BILINEAR)
        resized_np_array = np.array(resized_pil_image, dtype=np.float32)
        resized_slices_list.append(resized_np_array)
    
    stacked_volume = np.stack(resized_slices_list, axis=0)
    volume_min = float(stacked_volume.min())
    volume_max = float(stacked_volume.max())
    if volume_max <= volume_min: return None 
        
    stacked_volume = (stacked_volume - volume_min) / (volume_max - volume_min + 1e-6)
    return (stacked_volume * 255.0).clip(0, 255).astype(np.uint8)

def process_and_organize_data():
    """Divides patient folders and creates robust image triplets."""
    all_patient_dirs = sorted([d for d in RAW_SCAN_PATH.iterdir() if d.is_dir()])
    if len(all_patient_dirs) == 0:
        raise RuntimeError(f"No patient folders found in {RAW_SCAN_PATH}")

    random.shuffle(all_patient_dirs)
    total_patients = len(all_patient_dirs)
    
    n_train_pats = int(total_patients * TRAIN_RATIO)
    n_val_pats = int(total_patients * VALIDATION_RATIO)

    train_patients = all_patient_dirs[:n_train_pats]
    validation_patients = all_patient_dirs[n_train_pats:n_train_pats+n_val_pats]
    test_patients = all_patient_dirs[n_train_pats+n_val_pats:]

    patient_group_map = {}
    for p in train_patients: patient_group_map[p.name] = "train"
    for p in validation_patients: patient_group_map[p.name] = "val"
    for p in test_patients: patient_group_map[p.name] = "test"

    print(f"Total patients: {total_patients} | Train={len(train_patients)}, Val={len(validation_patients)}, Test={len(test_patients)}")

    train_slice_triplets, val_slice_triplets, test_slice_triplets = [], [], []

    split_to_output_dir = {
        "train": TRAIN_SET_HR_DIR,
        "val": VALIDATION_SET_HR_DIR,
        "test": TEST_SET_HR_DIR,
    }

    for patient_folder in tqdm(all_patient_dirs, desc="Preparing Data"):
        patient_id = patient_folder.name
        split_group = patient_group_map[patient_id]
        output_folder = split_to_output_dir[split_group]

        scan_volume = read_patient_scan_volume(patient_folder)
        if scan_volume is None: continue

        num_slices, H, W = scan_volume.shape
        if num_slices < 3: continue

        png_path_list = []
        for slice_index in range(num_slices):
            png_filename = f"{patient_id}_slice_{slice_index:03d}.png"
            png_filepath = output_folder / png_filename

            if not png_filepath.exists():
                image_slice = Image.fromarray(scan_volume[slice_index])
                image_slice.save(png_filepath)
            png_path_list.append(png_filepath)

        # Create triplets ONLY if all images are safe
        triplets_for_patient = []
        for i in range(1, num_slices - 1):
            prev_p = png_path_list[i - 1]
            mid_p = png_path_list[i]
            next_p = png_path_list[i + 1]
            
            # Robust Check: Only add triplet if files are valid
            if SafeImageCheck(prev_p) and SafeImageCheck(mid_p) and SafeImageCheck(next_p):
                triplets_for_patient.append((prev_p, mid_p, next_p))

        if split_group == "train": train_slice_triplets.extend(triplets_for_patient)
        elif split_group == "val": val_slice_triplets.extend(triplets_for_patient)
        else: test_slice_triplets.extend(triplets_for_patient)

    print(f"Data Ready: Train={len(train_slice_triplets)}, Val={len(val_slice_triplets)}, Test={len(test_slice_triplets)}")
    return train_slice_triplets, val_slice_triplets, test_slice_triplets

# ============================================================
# 3. PYTORCH DATA HANDLING
# ============================================================

def load_image_to_tensor(image_path, target_dim=(TARGET_HEIGHT, TARGET_WIDTH)):
    try:
        img = Image.open(image_path).convert("F") 
        if img.size != (target_dim[1], target_dim[0]):
            img = img.resize((target_dim[1], target_dim[0]), resample=Image.BILINEAR)
        numpy_array = np.array(img, dtype=np.float32) / 255.0
        return torch.from_numpy(numpy_array).unsqueeze(0)
    except Exception:
        # Fallback for safety (should be caught by SafeImageCheck mostly)
        return torch.zeros((1, target_dim[0], target_dim[1]), dtype=torch.float32)

class ScanSliceSequenceDataset(Dataset):
    def __init__(self, list_of_triplets):
        self.triplets = list_of_triplets

    def __len__(self): return len(self.triplets)

    def __getitem__(self, index):
        prev_path, mid_path, next_path = self.triplets[index]
        prev_slice = load_image_to_tensor(prev_path)
        mid_slice = load_image_to_tensor(mid_path)
        next_slice = load_image_to_tensor(next_path)
        
        model_input_x = torch.cat([prev_slice, next_slice], dim=0) 
        model_target_y = mid_slice
        return model_input_x, model_target_y

# ============================================================
# 4. METRICS & MODELS
# ============================================================

def calculate_psnr(prediction, ground_truth, max_value=1.0):
    mse = F.mse_loss(prediction, ground_truth, reduction="mean")
    if mse.item() == 0: return 99.0 
    return 20 * torch.log10(max_value / torch.sqrt(mse))

# 


class SimpleConvolutionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Conv2d(2, 64, 9, padding=4)
        self.non_linear_map = nn.Conv2d(64, 32, 5, padding=2)
        self.reconstruction = nn.Conv2d(32, 1, 5, padding=2)
    def forward(self, x):
        x = F.relu(self.feature_extractor(x))
        x = F.relu(self.non_linear_map(x))
        return torch.sigmoid(self.reconstruction(x))

class DeeperConvolutionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        layer_list = []
        input_channels = 2
        for output_channels in [64, 64, 64, 32]:
            layer_list.append(nn.Conv2d(input_channels, output_channels, 3, padding=1))
            layer_list.append(nn.ReLU(inplace=True))
            input_channels = output_channels
        layer_list.append(nn.Conv2d(input_channels, 1, 3, padding=1))
        self.network = nn.Sequential(*layer_list)
    def forward(self, x): return torch.sigmoid(self.network(x))

# GAN Components
class UNetConvBlock(nn.Module):
    def __init__(self, input_channels, output_channels, is_downsample=True):
        super().__init__()
        if is_downsample:
            self.operation = nn.Sequential(
                nn.Conv2d(input_channels, output_channels, 4, 2, 1),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, True),
            )
        else:
            self.operation = nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, 4, 2, 1),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(True),
            )
    def forward(self, x): return self.operation(x)

# 


class SliceSynthesizer(nn.Module):
    def __init__(self, base_filters=64):
        super().__init__()
        self.down_a = UNetConvBlock(2, base_filters, True)
        self.down_b = UNetConvBlock(base_filters, base_filters * 2, True)
        self.down_c = UNetConvBlock(base_filters * 2, base_filters * 4, True)
        self.deepest_point = nn.Sequential(nn.Conv2d(base_filters * 4, base_filters * 4, 3, 1, 1), nn.ReLU(True))
        self.up_c = UNetConvBlock(base_filters * 4, base_filters * 2, False)
        self.up_b = UNetConvBlock(base_filters * 4, base_filters, False)
        self.up_a = UNetConvBlock(base_filters * 2, base_filters, False)
        self.output_layer = nn.Conv2d(base_filters, 1, 3, padding=1)

    def forward(self, x):
        d_a = self.down_a(x)
        d_b = self.down_b(d_a)
        d_c = self.down_c(d_b)
        b = self.deepest_point(d_c)
        u_c = self.up_c(b); u_c = torch.cat([u_c, d_b], dim=1)
        u_b = self.up_b(u_c); u_b = torch.cat([u_b, d_a], dim=1)
        u_a = self.up_a(u_b)
        return torch.sigmoid(self.output_layer(u_a))

class RealityChecker(nn.Module):
    def __init__(self, input_channels=3, base_filters=64):
        super().__init__()
        layer_list = [
            nn.Conv2d(input_channels, base_filters, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(base_filters, base_filters * 2, 4, 2, 1), nn.BatchNorm2d(base_filters * 2), nn.LeakyReLU(0.2, True),
            nn.Conv2d(base_filters * 2, base_filters * 4, 4, 2, 1), nn.BatchNorm2d(base_filters * 4), nn.LeakyReLU(0.2, True),
            nn.Conv2d(base_filters * 4, 1, 3, 1, 1),
        ]
        self.network = nn.Sequential(*layer_list)
    def forward(self, input_slices, target_or_fake_slice):
        return self.network(torch.cat([input_slices, target_or_fake_slice], dim=1))

# ============================================================
# 5. TRAINING & VALIDATION LOOPS
# ============================================================

def train_sr_model_one_epoch(model, data_loader, optimizer, precision_scaler, hardware_device):
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    batch_count = 0

    for input_x, target_y in tqdm(data_loader, desc="Train SR", leave=False):
        input_x = input_x.to(hardware_device)
        target_y = target_y.to(hardware_device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=amp_dev_type, enabled=torch.cuda.is_available()):
            prediction = model(input_x)
            loss_value = F.l1_loss(prediction, target_y)

        precision_scaler.scale(loss_value).backward()
        precision_scaler.step(optimizer)
        precision_scaler.update()

        total_loss += loss_value.item()
        total_psnr += calculate_psnr(prediction.detach(), target_y.detach()).item()
        batch_count += 1

    return total_loss / batch_count, total_psnr / batch_count

@torch.no_grad()
def validate_sr_model(model, data_loader, hardware_device):
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    batch_count = 0
    for input_x, target_y in tqdm(data_loader, desc="Val SR", leave=False):
        input_x = input_x.to(hardware_device)
        target_y = target_y.to(hardware_device)
        with autocast(device_type=amp_dev_type, enabled=torch.cuda.is_available()):
            prediction = model(input_x)
            loss_value = F.l1_loss(prediction, target_y)
        total_loss += loss_value.item()
        total_psnr += calculate_psnr(prediction, target_y).item()
        batch_count += 1
    return total_loss / batch_count, total_psnr / batch_count

def train_gan_model_one_epoch(Generator, Discriminator, data_loader, opt_G, opt_D, scaler_G, scaler_D, hardware_device):
    Generator.train()
    Discriminator.train()
    bce_loss_fn = nn.BCEWithLogitsLoss() 
    total_G_loss = 0.0
    total_D_loss = 0.0
    total_psnr = 0.0
    batch_count = 0

    for input_x, target_y in tqdm(data_loader, desc="Train GAN", leave=False):
        input_x = input_x.to(hardware_device)
        target_y = target_y.to(hardware_device)

        opt_D.zero_grad(set_to_none=True)
        with autocast(device_type=amp_dev_type, enabled=torch.cuda.is_available()):
            fake_image = Generator(input_x).detach() 
            pred_on_real = Discriminator(input_x, target_y) 
            pred_on_fake = Discriminator(input_x, fake_image)
            loss_D = 0.5 * (bce_loss_fn(pred_on_real, torch.ones_like(pred_on_real)) + bce_loss_fn(pred_on_fake, torch.zeros_like(pred_on_fake)))

        scaler_D.scale(loss_D).backward()
        scaler_D.step(opt_D)
        scaler_D.update()

        opt_G.zero_grad(set_to_none=True)
        with autocast(device_type=amp_dev_type, enabled=torch.cuda.is_available()):
            fake_image = Generator(input_x)
            pred_on_fake_for_G = Discriminator(input_x, fake_image) 
            loss_G = bce_loss_fn(pred_on_fake_for_G, torch.ones_like(pred_on_fake_for_G)) + F.l1_loss(fake_image, target_y) * 100.0

        scaler_G.scale(loss_G).backward()
        scaler_G.step(opt_G)
        scaler_G.update()

        total_G_loss += loss_G.item()
        total_D_loss += loss_D.item()
        total_psnr += calculate_psnr(fake_image.detach(), target_y.detach()).item()
        batch_count += 1

    return (total_G_loss/batch_count, total_D_loss/batch_count, total_psnr/batch_count)

@torch.no_grad()
def validate_gan_generator(Generator, data_loader, hardware_device):
    Generator.eval()
    total_l1 = 0.0
    total_psnr = 0.0
    batch_count = 0
    for input_x, target_y in tqdm(data_loader, desc="Val GAN", leave=False):
        input_x = input_x.to(hardware_device)
        target_y = target_y.to(hardware_device)
        with autocast(device_type=amp_dev_type, enabled=torch.cuda.is_available()):
            fake_image = Generator(input_x)
            l1 = F.l1_loss(fake_image, target_y)
        total_l1 += l1.item()
        total_psnr += calculate_psnr(fake_image, target_y).item()
        batch_count += 1
    return total_l1 / batch_count, total_psnr / batch_count

# ============================================================
# 6. MAIN EXECUTION FLOW
# ============================================================

def run_interpolation_project():
    start_time = time.time()
    
    # 1. Prepare Data
    print("\n--- 1. Robust Data Preparation ---")
    train_triplets, val_triplets, test_triplets = process_and_organize_data()

    training_dataset = ScanSliceSequenceDataset(train_triplets)
    validation_dataset = ScanSliceSequenceDataset(val_triplets)
    testing_dataset = ScanSliceSequenceDataset(test_triplets)

    train_loader = DataLoader(training_dataset, batch_size=BATCH_SIZE_PER_STEP, shuffle=True, num_workers=DATA_WORKER_THREADS, pin_memory=True)
    val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE_PER_STEP, shuffle=False, num_workers=DATA_WORKER_THREADS, pin_memory=True)
    test_loader = DataLoader(testing_dataset, batch_size=BATCH_SIZE_PER_STEP, shuffle=False, num_workers=DATA_WORKER_THREADS, pin_memory=True)

    all_model_histories = {}

    # --- Train Standard Models ---
    base_sr_networks = {"SimpleConvolutionalModel": SimpleConvolutionalModel(), "DeeperConvolutionalModel": DeeperConvolutionalModel()}
    sr_networks = {}
    for name, m in base_sr_networks.items():
        m = m.to(processing_device)
        if USE_MULTI_GPU: m = nn.DataParallel(m)
        sr_networks[name] = m

    for name, model in sr_networks.items():
        print(f"\n===== Training {name} =====")
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        scaler = GradScaler(amp_dev_type, enabled=torch.cuda.is_available())
        history = defaultdict(list)
        best_psnr_score = -1.0

        for epoch in range(1, TOTAL_TRAINING_EPOCHS + 1):
            print(f"--- Epoch {epoch}/{TOTAL_TRAINING_EPOCHS} ---")
            train_l, train_p = train_sr_model_one_epoch(model, train_loader, optimizer, scaler, processing_device)
            val_l, val_p = validate_sr_model(model, val_loader, processing_device)

            history["train_loss"].append(train_l); history["train_psnr"].append(train_p)
            history["val_loss"].append(val_l); history["val_psnr"].append(val_p)
            print(f"  Train: L1={train_l:.4f}, PSNR={train_p:.2f} dB | Val: L1={val_l:.4f}, PSNR={val_p:.2f} dB")

            if val_p > best_psnr_score:
                best_psnr_score = val_p
                torch.save(extract_model_weights(model), f"best_{name}.pth")
        
        all_model_histories[name] = history

    # --- Train GAN ---
    print("\n===== Training GAN =====")
    Generator = SliceSynthesizer().to(processing_device)
    Discriminator = RealityChecker().to(processing_device)
    if USE_MULTI_GPU:
        Generator = nn.DataParallel(Generator)
        Discriminator = nn.DataParallel(Discriminator)

    opt_G = torch.optim.Adam(Generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(Discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    scaler_G = GradScaler(amp_dev_type, enabled=torch.cuda.is_available())
    scaler_D = GradScaler(amp_dev_type, enabled=torch.cuda.is_available())

    gan_history = defaultdict(list)
    best_gan_psnr_score = -1.0

    for epoch in range(1, TOTAL_TRAINING_EPOCHS + 1):
        print(f"--- Epoch {epoch}/{TOTAL_TRAINING_EPOCHS} ---")
        tg_l, td_l, tp_g = train_gan_model_one_epoch(Generator, Discriminator, train_loader, opt_G, opt_D, scaler_G, scaler_D, processing_device)
        vl1_g, vp_g = validate_gan_generator(Generator, val_loader, processing_device)

        gan_history["train_G_loss"].append(tg_l); gan_history["train_D_loss"].append(td_l); gan_history["train_psnr"].append(tp_g)
        gan_history["val_l1"].append(vl1_g); gan_history["val_psnr"].append(vp_g)
        print(f"  Train: G={tg_l:.4f}, D={td_l:.4f}, PSNR={tp_g:.2f} dB | Val: L1={vl1_g:.4f}, PSNR={vp_g:.2f} dB")

        if vp_g > best_gan_psnr_score:
            best_gan_psnr_score = vp_g
            torch.save(extract_model_weights(Generator), "best_SliceSynthesizer_G.pth")
            torch.save(extract_model_weights(Discriminator), "best_SliceSynthesizer_D.pth")

    all_model_histories["SliceSynthesizer"] = gan_history
    print("\nTraining complete.")

if __name__ == "__main__":
    run_interpolation_project()



PyTorch Version: 2.6.0+cu124
Using hardware: cuda | Available GPUs: 2
✅ Parallel Processing enabled across 2 GPUs

--- 1. Robust Data Preparation ---
Total patients: 842 | Train=589, Val=126, Test=127


Preparing Data:   0%|          | 0/842 [00:00<?, ?it/s]

  resized_pil_image = pil_image.resize((target_dim[1], target_dim[0]), resample=Image.BILINEAR)


⚠️ Found corrupt image: Prostate-MRI-US-Biopsy-0479_slice_020.png. Skipping associated triplets.
⚠️ Found corrupt image: Prostate-MRI-US-Biopsy-0479_slice_020.png. Skipping associated triplets.
⚠️ Found corrupt image: Prostate-MRI-US-Biopsy-0479_slice_020.png. Skipping associated triplets.
Data Ready: Train=41165, Val=7913, Test=8592

===== Training SimpleConvolutionalModel =====




--- Epoch 1/20 ---




Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0645, PSNR=20.01 dB | Val: L1=0.0511, PSNR=24.32 dB
--- Epoch 2/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0582, PSNR=20.39 dB | Val: L1=0.0472, PSNR=25.62 dB
--- Epoch 3/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0576, PSNR=20.43 dB | Val: L1=0.0474, PSNR=25.28 dB
--- Epoch 4/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0573, PSNR=20.44 dB | Val: L1=0.0468, PSNR=25.72 dB
--- Epoch 5/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0571, PSNR=20.49 dB | Val: L1=0.0464, PSNR=25.75 dB
--- Epoch 6/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0570, PSNR=20.48 dB | Val: L1=0.0466, PSNR=25.57 dB
--- Epoch 7/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0568, PSNR=20.49 dB | Val: L1=0.0461, PSNR=25.86 dB
--- Epoch 8/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0567, PSNR=20.48 dB | Val: L1=0.0460, PSNR=25.81 dB
--- Epoch 9/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0566, PSNR=20.51 dB | Val: L1=0.0460, PSNR=25.85 dB
--- Epoch 10/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0565, PSNR=20.53 dB | Val: L1=0.0457, PSNR=25.96 dB
--- Epoch 11/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0564, PSNR=20.54 dB | Val: L1=0.0459, PSNR=25.79 dB
--- Epoch 12/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0563, PSNR=20.53 dB | Val: L1=0.0460, PSNR=25.74 dB
--- Epoch 13/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0562, PSNR=20.56 dB | Val: L1=0.0457, PSNR=25.84 dB
--- Epoch 14/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0562, PSNR=20.57 dB | Val: L1=0.0467, PSNR=25.79 dB
--- Epoch 15/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0561, PSNR=20.58 dB | Val: L1=0.0454, PSNR=25.98 dB
--- Epoch 16/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0561, PSNR=20.58 dB | Val: L1=0.0456, PSNR=25.84 dB
--- Epoch 17/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0560, PSNR=20.60 dB | Val: L1=0.0472, PSNR=25.23 dB
--- Epoch 18/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0559, PSNR=20.63 dB | Val: L1=0.0458, PSNR=25.96 dB
--- Epoch 19/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0559, PSNR=20.62 dB | Val: L1=0.0456, PSNR=26.01 dB
--- Epoch 20/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0558, PSNR=20.63 dB | Val: L1=0.0454, PSNR=26.03 dB

===== Training DeeperConvolutionalModel =====
--- Epoch 1/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0616, PSNR=20.17 dB | Val: L1=0.0469, PSNR=25.58 dB
--- Epoch 2/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>Exception ignored in: 
self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>Exception ignored in: Traceback (most recent call last):


<function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Traceback (most recent call last):

          File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most 

  Train: L1=0.0572, PSNR=20.43 dB | Val: L1=0.0463, PSNR=25.79 dB
--- Epoch 3/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0568, PSNR=20.47 dB | Val: L1=0.0471, PSNR=25.40 dB
--- Epoch 4/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0566, PSNR=20.53 dB | Val: L1=0.0472, PSNR=25.27 dB
--- Epoch 5/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0564, PSNR=20.53 dB | Val: L1=0.0459, PSNR=25.78 dB
--- Epoch 6/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0562, PSNR=20.56 dB | Val: L1=0.0457, PSNR=25.83 dB
--- Epoch 7/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0561, PSNR=20.57 dB | Val: L1=0.0457, PSNR=25.99 dB
--- Epoch 8/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0559, PSNR=20.60 dB | Val: L1=0.0453, PSNR=25.91 dB
--- Epoch 9/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/apps/jupyter/6.5.4/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.

  Train: L1=0.0557, PSNR=20.62 dB | Val: L1=0.0454, PSNR=26.01 dB
--- Epoch 10/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0555, PSNR=20.66 dB | Val: L1=0.0449, PSNR=26.04 dB
--- Epoch 11/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0554, PSNR=20.72 dB | Val: L1=0.0450, PSNR=25.93 dB
--- Epoch 12/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820><function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820><function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>

Traceback (most recent call last):
Traceback (most recent call last):

  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
          File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
self._shutdown_workers()self._shutdown_workers()    
self._shutdown_workers()
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File 

  Train: L1=0.0552, PSNR=20.74 dB | Val: L1=0.0452, PSNR=26.15 dB
--- Epoch 13/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0550, PSNR=20.77 dB | Val: L1=0.0459, PSNR=25.29 dB
--- Epoch 14/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0549, PSNR=20.77 dB | Val: L1=0.0454, PSNR=25.39 dB
--- Epoch 15/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0547, PSNR=20.80 dB | Val: L1=0.0449, PSNR=25.55 dB
--- Epoch 16/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
<function _MultiProcessingDataLoaderIter.__del__ at 0x14bd970b4820>
Traceback (most recent call last):

Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()    self._shutdown_workers()
self._shutdown_workers()  File "/home/anany.sharma/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers


  File "/home/anany.sharma/.local/lib/python3.10/site-pa

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0546, PSNR=20.83 dB | Val: L1=0.0445, PSNR=26.11 dB
--- Epoch 17/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0545, PSNR=20.84 dB | Val: L1=0.0442, PSNR=25.85 dB
--- Epoch 18/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0544, PSNR=20.86 dB | Val: L1=0.0443, PSNR=26.15 dB
--- Epoch 19/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0543, PSNR=20.88 dB | Val: L1=0.0438, PSNR=26.05 dB
--- Epoch 20/20 ---


Train SR:   0%|          | 0/5146 [00:00<?, ?it/s]

Val SR:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: L1=0.0542, PSNR=20.88 dB | Val: L1=0.0445, PSNR=25.75 dB

===== Training GAN =====
--- Epoch 1/20 ---


Train GAN:   0%|          | 0/5146 [00:00<?, ?it/s]

Val GAN:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: G=7.8721, D=0.4947, PSNR=19.59 dB | Val: L1=0.0545, PSNR=23.10 dB
--- Epoch 2/20 ---


Train GAN:   0%|          | 0/5146 [00:00<?, ?it/s]

Val GAN:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: G=7.2847, D=0.5422, PSNR=19.96 dB | Val: L1=0.0534, PSNR=23.86 dB
--- Epoch 3/20 ---


Train GAN:   0%|          | 0/5146 [00:00<?, ?it/s]

Val GAN:   0%|          | 0/990 [00:00<?, ?it/s]

  Train: G=7.1823, D=0.5540, PSNR=20.00 dB | Val: L1=0.0505, PSNR=24.67 dB
--- Epoch 4/20 ---


Train GAN:   0%|          | 0/5146 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



  Train: G=6.8810, D=0.5704, PSNR=20.33 dB | Val: L1=0.0488, PSNR=24.75 dB
--- Epoch 16/20 ---


Train GAN:   0%|          | 0/5146 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [4]:
import torch
print("PyTorch Version:", torch.__version__)
print("CUDA Version PyTorch uses:", torch.version.cuda)

if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("GPU Capability Score:", torch.cuda.get_device_capability(0))
else:
    print("No GPU detected!")

PyTorch Version: 2.6.0+cu124
CUDA Version PyTorch uses: 12.4
GPU Name: NVIDIA L4
GPU Capability Score: (8, 9)
