<a href="https://colab.research.google.com/github/ArsipKodeId/ArsipKodeId.github.io/blob/main/GenAI_GLI_02D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install monai

Collecting monai
  Downloading monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Downloading monai-1.5.1-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.5.1


In [2]:
import os

def get_modality_paths(patient_folder_path):
  """
  Generates a dictionary of full file paths for MRI modalities and segmentation masks
  for a given patient folder.

  Args:
    patient_folder_path (str): The path to the patient's data folder.

  Returns:
    dict: A dictionary where keys are modality names (t1c, t1n, t2f, t2w, seg)
          and values are their corresponding full file paths.
  """
  modalities = ['t1c', 't1n', 't2f', 't2w', 'seg']
  modality_paths = {}

  # Extract patient_id from the patient_folder_path
  patient_id = os.path.basename(patient_folder_path)

  for modality in modalities:
    file_name = f"{patient_id}-{modality}.nii.gz"
    full_path = os.path.join(patient_folder_path, file_name)
    modality_paths[modality] = full_path

  return modality_paths

print("The 'get_modality_paths' function has been defined.")

The 'get_modality_paths' function has been defined.


In [3]:
import glob

tumor_img_start = 15
tumor_img_end = 111

root_path_50 = "/content/drive/MyDrive/Datasets/BraTS2023_GLI_Challenge/training_50"

In [4]:
from monai.transforms import HistogramNormalize
import nibabel as nib # Added for self-containment
import numpy as np



In [5]:
folders = glob.glob(f"{root_path_50}/***")
len(folders)

51

In [6]:
import torch
from torch.utils.data import Dataset

class BraTSDataset(Dataset):
  def __init__(self, patient_folders):
    self.patient_folders = patient_folders
    print(f"Initialized BraTSDataset with {len(self.patient_folders)} patient folders.")

  def __len__(self):
    return len(self.patient_folders)

  def __getitem__(self, index):
    # Placeholder for loading and preprocessing data for a single patient
    # This will be implemented in subsequent steps.
    patient_folder_path = self.patient_folders[index]
    # For now, just return the path to demonstrate it's working
    return patient_folder_path

print("The BraTSDataset class has been defined.")

The BraTSDataset class has been defined.


In [7]:
import torch
from torch.utils.data import Dataset
from monai.transforms import HistogramNormalize, ScaleIntensityRange
import nibabel as nib
import numpy as np

class BraTSDataset(Dataset):
  def __init__(self, patient_folders):
    self.patient_folders = patient_folders
    self.histogram_normalize = HistogramNormalize(num_bins=256)
    self.num_slices_per_volume = tumor_img_end - tumor_img_start
    print(f"Initialized BraTSDataset with {len(self.patient_folders)} patient folders.")

  def __len__(self):
    # Now returns the total number of 2D slices across all patients
    return len(self.patient_folders) * self.num_slices_per_volume

  def __getitem__(self, index):
    # Determine which patient and which slice within that patient corresponds to the index
    patient_idx = index // self.num_slices_per_volume
    slice_in_volume_idx = index % self.num_slices_per_volume
    actual_slice_idx = tumor_img_start + slice_in_volume_idx

    patient_folder_path = self.patient_folders[patient_idx]
    modality_paths = get_modality_paths(patient_folder_path)

    # Load NIfTI images for each modality
    t1c_img = nib.load(modality_paths['t1c']).get_fdata()
    t1n_img = nib.load(modality_paths['t1n']).get_fdata()
    t2f_img = nib.load(modality_paths['t2f']).get_fdata()
    t2w_img = nib.load(modality_paths['t2w']).get_fdata()

    # Convert numpy arrays to torch tensors
    t1c_tensor = torch.from_numpy(t1c_img).float()
    t1n_tensor = torch.from_numpy(t1n_img).float()
    t2f_tensor = torch.from_numpy(t2f_img).float()
    t2w_tensor = torch.from_numpy(t2w_img).float()

    # Apply HistogramNormalize (Monai's HistogramNormalize expects channel-first, so unsqueeze for 3D image)
    t1c_normalized = self.histogram_normalize(t1c_tensor.unsqueeze(0)).squeeze(0)
    t1n_normalized = self.histogram_normalize(t1n_tensor.unsqueeze(0)).squeeze(0)
    t2f_normalized = self.histogram_normalize(t2f_tensor.unsqueeze(0)).squeeze(0)
    t2w_normalized = self.histogram_normalize(t2w_tensor.unsqueeze(0)).squeeze(0)

    # Apply ScaleIntensityRange for min-max scaling to [0, 1] for each modality
    def monai_min_max_scale(tensor):
      min_val = tensor.min()
      max_val = tensor.max()
      if max_val - min_val == 0:
        return torch.zeros_like(tensor)
      # Create ScaleIntensityRange for this specific tensor's min/max
      scale_transform = ScaleIntensityRange(a_min=min_val, a_max=max_val, b_min=0.0, b_max=1.0, clip=True)
      return scale_transform(tensor)

    t1c_scaled = monai_min_max_scale(t1c_normalized)
    t1n_scaled = monai_min_max_scale(t1n_normalized)
    t2f_scaled = monai_min_max_scale(t2f_normalized)
    t2w_scaled = monai_min_max_scale(t2w_normalized)

    # Extract a single axial slice at actual_slice_idx
    t1c_slice = t1c_scaled[:, :, actual_slice_idx]
    t1n_slice = t1n_scaled[:, :, actual_slice_idx]
    t2f_slice = t2f_scaled[:, :, actual_slice_idx]
    t2w_slice = t2w_scaled[:, :, actual_slice_idx]

    # Stack preprocessed MRI modalities into an input tensor (t1n, t2f, t2w)
    # Input tensor will have shape (C, H, W) where C=3 for 3 modalities
    input_tensor = torch.stack([t1n_slice, t2f_slice, t2w_slice], dim=0)

    # Target tensor for t1c (image to image translation)
    # Add channel dimension to the 2D slice
    target_tensor = t1c_slice.unsqueeze(0)

    return input_tensor, target_tensor

print("The BraTSDataset class has been updated to yield individual 2D axial slices for image-to-image translation.")

The BraTSDataset class has been updated to yield individual 2D axial slices for image-to-image translation.


In [18]:
from torch.utils.data import DataLoader

# Instantiate the BraTSDataset
brats_dataset = BraTSDataset(folders[:25])

# Create a DataLoader
batch_size = 1 # For demonstration, we'll use a batch size of 1
data_loader = DataLoader(brats_dataset, batch_size=batch_size, shuffle=False)

print(f"Initialized DataLoader with {len(brats_dataset)} samples and batch size {batch_size}.")

# Get a single sample from the data loader to verify
for i, (input_tensor, target_tensor) in enumerate(data_loader):
  if i == 0:
    print(f"Shape of input tensor: {input_tensor.shape}")
    print(f"Shape of target tensor: {target_tensor.shape}")
    break


Initialized BraTSDataset with 25 patient folders.
Initialized DataLoader with 2400 samples and batch size 1.
Shape of input tensor: torch.Size([1, 3, 240, 240])
Shape of target tensor: torch.Size([1, 1, 240, 240])


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
  """Helper block for two convolutional layers with Batch Normalization and ReLU."""
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.double_conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self, x):
    return self.double_conv(x)


class Down(nn.Module):
  """Downsampling block with MaxPool and DoubleConv."""
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.maxpool_conv = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_channels, out_channels)
    )

  def forward(self, x):
    return self.maxpool_conv(x)


class Up(nn.Module):
  """Upsampling block with TransposedConv, concatenation, and DoubleConv."""
  def __init__(self, in_channels, out_channels, bilinear=False):
    super().__init__()

    # if bilinear, use the normal convolutions to reduce the number of channels
    if bilinear:
      self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
      self.conv = DoubleConv(in_channels, out_channels)
    else:
      self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
      self.conv = DoubleConv(in_channels, out_channels)

  def forward(self, x1, x2):
    x1 = self.up(x1)
    # input is CHW
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]

    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                    diffY // 2, diffY - diffY // 2])
    # If you have padding issues, see
    # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
    # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
    x = torch.cat([x2, x1], dim=1)
    return self.conv(x)


class UNet(nn.Module):
  """Full U-Net architecture for image segmentation or image-to-image translation."""
  def __init__(self, n_channels, n_classes, bilinear=False):
    super(UNet, self).__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.bilinear = bilinear

    # Encoder (Downsampling path)
    self.inc = DoubleConv(n_channels, 64)
    self.down1 = Down(64, 128)
    self.down2 = Down(128, 256)
    self.down3 = Down(256, 512)
    factor = 2 if bilinear else 1
    self.down4 = Down(512, 1024 // factor)

    # Decoder (Upsampling path)
    self.up1 = Up(1024, 512 // factor, bilinear)
    self.up2 = Up(512, 256 // factor, bilinear)
    self.up3 = Up(256, 128 // factor, bilinear)
    self.up4 = Up(128, 64, bilinear)
    self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

  def forward(self, x):
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)

    x = self.up1(x5, x4) # Concatenate with skip connection x4
    x = self.up2(x, x3)  # Concatenate with skip connection x3
    x = self.up3(x, x2)  # Concatenate with skip connection x2
    x = self.up4(x, x1)  # Concatenate with skip connection x1
    logits = self.outc(x)
    return logits

print("UNet model for Generator defined successfully.")

UNet model for Generator defined successfully.


In [21]:
# Instantiate the UNet Generator
# Input channels: 3 (t1n, t2f, t2w)
# Output channels: 1 (t1c)

generator = UNet(n_channels=3, n_classes=1)

print("UNet Generator instantiated.")

# Get a single batch from the DataLoader
for i, (input_tensor, target_tensor) in enumerate(data_loader):
  if i == 0:
    # Move tensors to the same device as the model if using GPU
    # For now, assuming CPU
    # input_tensor = input_tensor.to(device)
    # target_tensor = target_tensor.to(device)

    # Pass the input through the generator
    output_tensor = generator(input_tensor)

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Output tensor shape from Generator: {output_tensor.shape}")
    print(f"Target tensor shape: {target_tensor.shape}")
    break

print("Generator output shape verified.")

UNet Generator instantiated.
Input tensor shape: torch.Size([1, 3, 240, 240])
Output tensor shape from Generator: torch.Size([1, 1, 240, 240])
Target tensor shape: torch.Size([1, 1, 240, 240])
Generator output shape verified.


In [22]:
import torch.nn as nn

class Discriminator(nn.Module):
  """PatchGAN Discriminator architecture."""
  def __init__(self, in_channels=4, features=[64, 128, 256, 512]):
    super().__init__()
    # The discriminator takes both input image (3 channels) and target image (1 channel)
    # concatenated as input, so in_channels will be 3+1 = 4.
    self.initial_block = nn.Sequential(
        nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    )

    layers = []
    # Downsampling layers
    for i in range(len(features) - 1):
        in_f = features[i]
        out_f = features[i+1]
        layers += [
            nn.Conv2d(in_f, out_f, kernel_size=4, stride=2 if i != len(features) - 2 else 1, padding=1, bias=False),
            nn.BatchNorm2d(out_f),
            nn.LeakyReLU(0.2, inplace=True),
        ]

    # Final output layer to produce a 1-channel output (probability map)
    layers += [
        nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1, bias=False)
    ]

    self.model = nn.Sequential(*layers)

  def forward(self, x):
    x = self.initial_block(x)
    return self.model(x)

print("PatchGAN Discriminator architecture defined successfully.")

PatchGAN Discriminator architecture defined successfully.


In [23]:
# Instantiate the PatchGAN Discriminator
# Input channels: 3 (from input modalities) + 1 (from target/generated t1c) = 4
discriminator = Discriminator(in_channels=4)

print("PatchGAN Discriminator instantiated.")

# Get a single batch from the DataLoader to test the discriminator
for i, (input_tensor, target_tensor) in enumerate(data_loader):
  if i == 0:
    # Concatenate the input_tensor and target_tensor along the channel dimension
    # input_tensor has shape (batch_size, 3, H, W)
    # target_tensor has shape (batch_size, 1, H, W)
    # Discriminator input should have shape (batch_size, 4, H, W)
    discriminator_input = torch.cat([input_tensor, target_tensor], dim=1)

    # Pass the concatenated input through the discriminator
    discriminator_output = discriminator(discriminator_input)

    print(f"Discriminator input tensor shape: {discriminator_input.shape}")
    print(f"Discriminator output tensor shape: {discriminator_output.shape}")
    break

print("Discriminator output shape verified.")

PatchGAN Discriminator instantiated.
Discriminator input tensor shape: torch.Size([1, 4, 240, 240])
Discriminator output tensor shape: torch.Size([1, 1, 28, 28])
Discriminator output shape verified.


In [24]:
import torch.nn as nn

# 2. Instantiate L1Loss
l1_loss = nn.L1Loss()

# 3. Instantiate BCEWithLogitsLoss for adversarial loss
adversarial_loss = nn.BCEWithLogitsLoss()

# 4. Instantiate MSELoss
mse_loss = nn.MSELoss()

print("L1 Loss, Adversarial Loss (BCEWithLogitsLoss), and MSE Loss instantiated successfully.")

L1 Loss, Adversarial Loss (BCEWithLogitsLoss), and MSE Loss instantiated successfully.


In [25]:
from monai.metrics import PSNRMetric, SSIMMetric
import torch

# Instantiate the metric classes once
# PSNRMetric requires 'max_val' to be specified during initialization.
psnr_metric_calculator = PSNRMetric(max_val=1.0)
# Removed 'size_average', 'gaussian_weights', 'K1', 'K2', 'kernel_size', 'sigma' as they're causing TypeErrors
ssim_metric_calculator = SSIMMetric(spatial_dims=2)

def calculate_psnr(predicted_image, target_image):
  """
  Calculates the Peak Signal-to-Noise Ratio (PSNR) between two PyTorch tensors.
  Args:
    predicted_image (torch.Tensor): The predicted image tensor (B, C, H, W).
    target_image (torch.Tensor): The target image tensor (B, C, H, W).
  Returns:
    torch.Tensor: The PSNR value.
  """
  # PSNRMetric expects lists of tensors
  return psnr_metric_calculator([predicted_image], [target_image])

def calculate_ssim(predicted_image, target_image):
  """
  Calculates the Structural Similarity Index Measure (SSIM) between two PyTorch tensors.
  Args:
    predicted_image (torch.Tensor): The predicted image tensor (B, C, H, W).
    target_image (torch.Tensor): The target image tensor (B, C, H, W).
  Returns:
    torch.Tensor: The SSIM value.
  """
  # SSIMMetric expects lists of tensors, and 'data_range' is no longer passed as it caused a TypeError
  # Assuming images are already normalized to [0, 1]
  return ssim_metric_calculator([predicted_image], [target_image])

print("PSNR and SSIM calculation functions defined successfully using MONAI Metric classes.")

PSNR and SSIM calculation functions defined successfully using MONAI Metric classes.


In [26]:
import torch.optim as optim

# Learning rates (these are typical starting points for GANs)
lr_g = 2e-4  # Learning rate for the Generator
lr_d = 2e-4  # Learning rate for the Discriminator

# Adam optimizers for Generator and Discriminator
optimizer_g = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))

print("Adam optimizers for Generator and Discriminator initialized successfully.")

Adam optimizers for Generator and Discriminator initialized successfully.


In [27]:
import torch

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move models to the selected device
generator.to(device)
discriminator.to(device)

# Training parameters
num_epochs = 1
lambda_l1 = 100 # Weight for L1 loss in the Generator's objective (as in Pix2Pix paper)

print(f"Training will be performed on: {device}")
print(f"Number of epochs: {num_epochs}")
print(f"Lambda for L1 loss: {lambda_l1}")

Training will be performed on: cuda
Number of epochs: 1
Lambda for L1 loss: 100


In [28]:
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd # Import pandas for CSV handling

print("Starting training loop setup.")

# Adjust num_epochs for a more complete training run
num_epochs = 25

# Lists to store losses and metrics for plotting/logging
g_losses = []
d_losses = []
psnr_scores = []
ssim_scores = []

# List to store all epoch results for CSV export
epoch_results = []

for epoch in range(num_epochs):
  print(f"Epoch {epoch+1}/{num_epochs}")

  # Set models to training mode
  generator.train()
  discriminator.train()

  # Initialize epoch-wise loss accumulation
  epoch_g_loss = 0.0
  epoch_d_loss = 0.0
  epoch_gen_adversarial_loss = 0.0
  epoch_gen_l1_loss = 0.0
  num_batches = 0

  # Iterate through the data_loader with tqdm for progress tracking
  for batch_idx, (input_tensor, target_tensor) in enumerate(tqdm(data_loader, desc=f"Epoch {epoch+1}")):
    # Move tensors to the device
    input_tensor = input_tensor.to(device)
    target_tensor = target_tensor.to(device);

    # --- Discriminator training step ---
    optimizer_d.zero_grad()

    # 1. Train with real images
    # Concatenate input_tensor and original target_tensor for real input
    real_discriminator_input = torch.cat([input_tensor, target_tensor], dim=1);
    real_output = discriminator(real_discriminator_input);
    # Create a tensor of ones for real labels
    real_labels = torch.ones_like(real_output).to(device);
    d_real_loss = adversarial_loss(real_output, real_labels);

    # 2. Train with fake images
    # Generate fake target_tensor
    fake_target_tensor = generator(input_tensor);
    # Detach fake_target_tensor to prevent gradients from flowing back to the generator
    fake_discriminator_input = torch.cat([input_tensor, fake_target_tensor.detach()], dim=1);
    fake_output = discriminator(fake_discriminator_input);
    # Create a tensor of zeros for fake labels
    fake_labels = torch.zeros_like(fake_output).to(device);
    d_fake_loss = adversarial_loss(fake_output, fake_labels);

    # 3. Combine losses and update Discriminator
    d_loss = d_real_loss + d_fake_loss;
    d_loss.backward();
    optimizer_d.step();
    # --- End Discriminator training step ---

    # --- Generator training step ---
    optimizer_g.zero_grad()

    # Generate fake target_tensor (already done above, but we need gradients now)
    fake_target_tensor_for_gen = generator(input_tensor);

    # 1. Adversarial loss for Generator (Generator wants discriminator to think fakes are real)
    # Discriminator input with fake images for generator loss calculation
    gen_discriminator_input = torch.cat([input_tensor, fake_target_tensor_for_gen], dim=1);
    gen_output = discriminator(gen_discriminator_input);
    # Generator wants discriminator to output ones for fake images
    gen_adversarial_loss = adversarial_loss(gen_output, real_labels);

    # 2. L1 loss (reconstruction loss)
    gen_l1_loss = l1_loss(fake_target_tensor_for_gen, target_tensor);

    # 3. Combine losses and update Generator
    g_loss = gen_adversarial_loss + lambda_l1 * gen_l1_loss;
    g_loss.backward();
    optimizer_g.step();
    # --- End Generator training step ---

    epoch_g_loss += g_loss.item();
    epoch_d_loss += d_loss.item();
    epoch_gen_adversarial_loss += gen_adversarial_loss.item();
    epoch_gen_l1_loss += gen_l1_loss.item();
    num_batches += 1;

  # --- Evaluation and Logging after each epoch ---
  generator.eval()
  discriminator.eval() # Discriminator eval mode, though not strictly necessary for metrics

  total_psnr = 0.0
  total_ssim = 0.0
  total_mse_eval_loss = 0.0
  eval_samples = 0

  with torch.no_grad():
    for input_tensor_eval, target_tensor_eval in data_loader:
      input_tensor_eval = input_tensor_eval.to(device);
      target_tensor_eval = target_tensor_eval.to(device);

      generated_image = generator(input_tensor_eval);

      # Fix: Squeeze the batch dimension (N=1) to convert (1, C, H, W) to (C, H, W)
      # This is because MONAI's metric classes internally add a batch dimension for each item in the list.
      generated_image_squeezed = generated_image.squeeze(0);
      target_tensor_eval_squeezed = target_tensor_eval.squeeze(0);

      # Ensure the generated image and target image have the same C, H, W for metrics
      # MONAI metrics expect [N, C, H, W] where N is the internal batch dim they create,
      # so we provide (C, H, W) as the item in the list.
      current_psnr = calculate_psnr(generated_image_squeezed, target_tensor_eval_squeezed);
      current_ssim = calculate_ssim(generated_image_squeezed, target_tensor_eval_squeezed);
      current_mse_loss = mse_loss(generated_image, target_tensor_eval);

      total_psnr += current_psnr.item() * input_tensor_eval.size(0);
      total_ssim += current_ssim.item() * input_tensor_eval.size(0);
      total_mse_eval_loss += current_mse_loss.item() * input_tensor_eval.size(0);
      eval_samples += input_tensor_eval.size(0);

  avg_psnr = total_psnr / eval_samples;
  avg_ssim = total_ssim / eval_samples;
  avg_mse_eval_loss = total_mse_eval_loss / eval_samples;

  # Store losses and metrics
  avg_g_loss = epoch_g_loss / num_batches;
  avg_d_loss = epoch_d_loss / num_batches;
  avg_gen_adversarial_loss = epoch_gen_adversarial_loss / num_batches;
  avg_gen_l1_loss = epoch_gen_l1_loss / num_batches;

  g_losses.append(avg_g_loss);
  d_losses.append(avg_d_loss);
  psnr_scores.append(avg_psnr);
  ssim_scores.append(avg_ssim);

  # Store epoch results for CSV
  epoch_results.append({
      'Epoch': epoch + 1,
      'Avg_Generator_Loss': avg_g_loss,
      'Avg_Discriminator_Loss': avg_d_loss,
      'Avg_Gen_Adversarial_Loss': avg_gen_adversarial_loss,
      'Avg_Gen_L1_Loss': avg_gen_l1_loss,
      'Avg_MSE_Loss_Eval': avg_mse_eval_loss,
      'Avg_PSNR': avg_psnr,
      'Avg_SSIM': avg_ssim
  });

  print(f"Epoch {epoch+1} Results:")
  print(f"  Avg Generator Loss: {avg_g_loss:.4f}")
  print(f"  Avg Discriminator Loss: {avg_d_loss:.4f}")
  print(f"  Avg Gen Adversarial Loss: {avg_gen_adversarial_loss:.4f}")
  print(f"  Avg Gen L1 Loss: {avg_gen_l1_loss:.4f}")
  print(f"  Avg MSE Loss (Eval): {avg_mse_eval_loss:.4f}")
  print(f"  Avg PSNR: {avg_psnr:.4f}")
  print(f"  Avg SSIM: {avg_ssim:.4f}")
  # --- End Evaluation and Logging ---

print("\nTraining complete.")

# Save training metrics to CSV
df_results = pd.DataFrame(epoch_results);
df_results.to_csv('training_metrics.csv', index=False);
print("Training metrics saved to training_metrics.csv")

# Save the models
torch.save(generator.state_dict(), 'generator_model.pth')
torch.save(discriminator.state_dict(), 'discriminator_model.pth')
print("Generator and Discriminator models saved.")

Starting training loop setup.
Epoch 1/25


Epoch 1:  25%|██▍       | 595/2400 [25:51<1:18:25,  2.61s/it]


KeyboardInterrupt: 