# Develop a lightweight GAN for post-processing
* The input must be the biggest volume size in the dataset
* Both the Generator and the Discriminator must fit in the GPU (48 GB VRAM)


### Check the biggest shape (384, 384, 384, although only one case has this shape)

In [None]:
# Find the biggest shape
import os
import tifffile as tiff
import numpy as np
import nibabel as nib
# Collect the shape of every NIfTI label volume in the dataset folder so the
# largest extent per axis can be computed afterwards.
all_shapes = []
root_dataset = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop"
for file_name in os.listdir(root_dataset):
    if not file_name.endswith('.nii.gz'):
        continue
    file_path = os.path.join(root_dataset, file_name)
    # Read the shape from the image header: get_fdata() would load the whole
    # voxel array into memory only to throw it away again.
    volume_shape = nib.load(file_path).shape
    all_shapes.append(volume_shape)
    print(volume_shape)

In [None]:
# Split the collected shapes into one list of extents per spatial axis.
x_all_shapes = [dims[0] for dims in all_shapes]
y_all_shapes = [dims[1] for dims in all_shapes]
z_all_shapes = [dims[2] for dims in all_shapes]

# Report the largest extent seen along each axis.
print(f"Biggest x: {max(x_all_shapes)}")
print(f"Biggest y: {max(y_all_shapes)}")
print(f"Biggest z: {max(z_all_shapes)}")

### Create MedNeXt network

In [None]:
def test_train_step():
    """Run one AMP training step of the discriminator and report memory use.

    Builds ``PatchDiscriminator(in_channels=2, initial_filters=16)``, feeds it
    a random ``(1, 2, 320, 320, 320)`` tensor, performs a single
    forward/backward/optimizer step with an MSE (LSGAN-style) loss against a
    random target, and prints the peak CUDA memory (or the process RSS when
    running on CPU).
    """
    # 1. Setup Device
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.empty_cache()
        # Reset so the peak reported at the end covers only this step.
        torch.cuda.reset_peak_memory_stats()
    else:
        device = torch.device('cpu')

    print(f"Running on {device}...")

    # 2. Initialize Model, Optimizer, and Scaler (for AMP)
    discriminator = PatchDiscriminator(in_channels=2, initial_filters=16).to(device)
    optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    criterion = nn.MSELoss()  # Least Squares GAN (LSGAN) loss
    # Bind the scaler to the device actually in use and only enable loss
    # scaling on CUDA; the default GradScaler() always targets CUDA.
    scaler = GradScaler(device.type, enabled=device.type == 'cuda')

    # 3. Create Inputs
    # Input: [B, 2, 320, 320, 320] -> (CT + Predicted_Mask)
    input_tensor = torch.randn(1, 2, 320, 320, 320, device=device)

    print("Starting training step with AMP...")

    # --- TRAINING STEP START ---
    optimizer.zero_grad()

    # Forward pass with AMP
    with autocast(device_type=device.type):
        output = discriminator(input_tensor)
        print(f"Discriminator Output Shape: {output.shape}")

        # The target must follow the discriminator's actual output shape
        # (a patch map), not the 320x320x320 input shape.
        target = torch.randn_like(output, device=device)

        loss = criterion(output, target)

    # Backward pass with Scaler
    print("Computing Gradients...")
    scaler.scale(loss).backward()

    # Optimizer Step
    scaler.step(optimizer)
    scaler.update()
    # --- TRAINING STEP END ---

    print(f"Step successful. Loss: {loss.item():.4f}")

    # 4. Measure Memory
    if device.type == 'cuda':
        peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"Peak Training VRAM: {peak_mem:.2f} GiB")
    else:
        # psutil is not imported anywhere else in this notebook; import it
        # here because only the CPU fallback needs it.
        import psutil

        process = psutil.Process(os.getpid())
        print(f"System RAM used: {process.memory_info().rss / (1024 ** 3):.2f} GiB")


if __name__ == "__main__":
    test_train_step()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from torch.nn.utils import spectral_norm
def run_test():
    """Smoke-test the Generator / PatchDiscriminator training loop on random data.

    Creates one batch of random 64^3 volumes (CT, coarse logits, binary mask),
    then runs three alternating discriminator/generator updates under AMP.
    After the third step it prints whether the generator loss differs from the
    first step, as a quick check that gradients reach the generator.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running test on: {device}")

    # A. Create Dummy Data
    # Batch=2, Size=64^3 (Small enough for CPU testing if needed)
    B, D, H, W = 2, 64, 64, 64

    # CT Scan (Normalized)
    ct_data = torch.randn(B, 1, D, H, W)
    # Coarse Logits (Simulated output from previous net)
    coarse_logits = torch.randn(B, 1, D, H, W)
    # Ground Truth (Binary Mask 0 or 1)
    gt_mask = torch.randint(0, 2, (B, 1, D, H, W)).float()

    dataset = TensorDataset(ct_data, coarse_logits, gt_mask)
    # Use B so the loader stays in sync with the dummy data above.
    dataloader = DataLoader(dataset, batch_size=B)

    # B. Initialize Models & Optimizers
    gen = Generator().to(device)
    disc = PatchDiscriminator().to(device)

    opt_g = optim.Adam(gen.parameters(), lr=1e-3)
    opt_d = optim.Adam(disc.parameters(), lr=1e-3)

    criterion_GAN = nn.MSELoss()
    criterion_BCE = nn.BCEWithLogitsLoss()
    criterion_Tversky = TverskyLoss()
    # One scaler shared by both optimizers; bound to the active device and
    # only enabled on CUDA (the default GradScaler() always targets CUDA).
    scaler = GradScaler(device.type, enabled=device.type == 'cuda')

    LAMBDA_SEG = 100.0
    LAMBDA_ADV = 1.0

    print("Setup complete. Starting training loop...")

    # ==========================================
    # 3. The Training Loop Logic
    # ==========================================
    gen.train()
    disc.train()

    # Run 3 steps to verify gradients actually change
    initial_loss = None

    for step in range(3):
        for real_ct, coarse, real_gt in dataloader:
            real_ct, coarse, real_gt = real_ct.to(device), coarse.to(device), real_gt.to(device)

            # --- 1. Discriminator Step ---
            opt_d.zero_grad()
            with autocast(device_type=device.type):
                # Real
                d_input_real = torch.cat([real_ct, real_gt], dim=1)
                d_real = disc(d_input_real)
                loss_d_real = criterion_GAN(d_real, torch.ones_like(d_real))

                # Fake
                g_input = torch.cat([real_ct, coarse], dim=1)
                refined_logits = gen(g_input)
                refined_probs = torch.sigmoid(refined_logits)

                # detach(): the discriminator update must not backpropagate
                # into the generator.
                d_input_fake = torch.cat([real_ct, refined_probs.detach()], dim=1)
                d_fake = disc(d_input_fake)
                loss_d_fake = criterion_GAN(d_fake, torch.zeros_like(d_fake))

                loss_d = (loss_d_real + loss_d_fake) * 0.5

            scaler.scale(loss_d).backward()
            scaler.step(opt_d)

            # --- 2. Generator Step ---
            opt_g.zero_grad()
            with autocast(device_type=device.type):
                # Adversarial (Fool D)
                # Re-run D on the non-detached probabilities so gradients
                # flow back to G. D also accumulates gradients here; they are
                # cleared by opt_d.zero_grad() at the next iteration.
                d_input_g = torch.cat([real_ct, refined_probs], dim=1)
                d_pred = disc(d_input_g)
                loss_g_adv = criterion_GAN(d_pred, torch.ones_like(d_pred))

                # Segmentation (Structure)
                loss_g_seg = criterion_BCE(refined_logits, real_gt) + criterion_Tversky(refined_logits, real_gt)

                loss_g = (LAMBDA_SEG * loss_g_seg) + (LAMBDA_ADV * loss_g_adv)

            scaler.scale(loss_g).backward()
            scaler.step(opt_g)
            # A shared scaler is updated once, after every optimizer stepped.
            scaler.update()

            print(f"Step {step+1}: D Loss={loss_d.item():.4f}, G Loss={loss_g.item():.4f}")

            if step == 0:
                initial_loss = loss_g.item()
            elif step == 2:
                if loss_g.item() != initial_loss:
                    print("\nSUCCESS: Loss is changing. Gradients are flowing.")
                else:
                    print("\nWARNING: Loss is identical. Check for detached gradients.")


if __name__ == "__main__":
    run_test()

### Building the training class

In [None]:

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
import torch
import json
import sys
from os.path import join
sys.path.append("../utils")
from main_train_class import main_train_STU_Net
from GANs_networks import Generator, PatchDiscriminator
from tqdm import tqdm
from torch.nn.functional import sigmoid, binary_cross_entropy_with_logits
# Standard Library Imports
from os.path import join
import sys

import json

# Third-Party Library Imports
import torch
import torch.optim as optim
from torch.nn.functional import sigmoid
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from torch.nn import MSELoss
# MONAI Specific Imports
import monai
from monai.data import CacheDataset

from monai.transforms import (
    Compose,
    CopyItemsd,
    LoadImaged, 
    ScaleIntensityRanged, 
    ResizeWithPadOrCropd,
    EnsureTyped,
    EnsureChannelFirstd
)
from mask_utils import GetROIMaskdd, GetBinaryLabeld
# Local Project Imports

class postprocessConvNeXt(main_train_STU_Net):
    def __init__(self, config):
        """Build every training component from ``config`` and optionally resume.

        The statement order matters: the optimizer is built from the model,
        the scheduler is built from the optimizer, and ``_resume`` loads state
        into both the model and the optimizer. The parent ``__init__`` is not
        called here.

        Args:
            config: Mapping of settings. This method reads ``'labda_seg'``
                directly; the helper methods read further keys.
        """
        # TODO predict and save logits from the segmentation network
        self.config = config
        self.labda_seg = self.config['labda_seg']
        self.G_model = self._build_models()
        self.train_loader = self._set_train_dataloader() 
        self.val_loader = self._set_val_dataloader() 
        self.opt_G = self._set_optimizers()
        # "Heritage" helpers are not defined in this class; presumably they
        # come from main_train_STU_Net - TODO confirm.
        self.wandb_run = self._set_wandb_checkpoint() # Heritage
        # Expects lists of predictions # TODO put predictions inside of a list
        self.G_voxel_criterion = self._set_train_criterion() # Heritage
        
        self.val_metric = self._set_val_metric() # Heritage
        self.G_cosAnnealLR = self._set_scheduler()

        # set scaler for mix precision 
        self.G_scaler = GradScaler()
        
        # check resume (sets self.start_epoch and self.val_value in both branches)
        self._resume()

    def _build_models(self):
        """Create the refinement network and move it to ``config['device']``.

        Returns:
            A ``new_ConvNeXt`` built with 2 input channels, 16 first channels,
            1 output channel and checkpointing enabled.
        """
        # TODO change network
        network = new_ConvNeXt(
            in_channels=2,
            first_channels=16,
            out_channels=1,
            use_checkpointing=True,
        )
        return network.to(self.config['device'])
    
    def _resume(self):
        """Restore model and optimizer state from ``config['resume']`` if set.

        Without a checkpoint path, training starts at epoch 0 with a best
        validation value of 0. With one, the model weights and optimizer state
        are loaded, training continues at the epoch after the saved one, and
        the saved validation value becomes the starting best value.
        """
        checkpoint_path = self.config.get('resume')
        if not checkpoint_path:
            self.start_epoch = 0
            self.val_value = 0
            return

        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints that come from a trusted source.
        state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        self.G_model.load_state_dict(state['model_weights'], strict=True)
        self.G_model = self.G_model.to(self.config['device'])

        # optimizer load
        # NOTE(review): no scheduler state is stored in the checkpoint, so the
        # LR schedule is not restored here.
        self.opt_G.load_state_dict(state['optimizer_state_dict'])

        # Continue with the next epoch instead of repeating the saved one.
        self.start_epoch = state['epoch'] + 1
        self.val_value = state['val_value']

    def _set_train_dataloader(self):
        """Build the MONAI training DataLoader.

        Each case contributes a dict with the paths of its image, ground
        truth, bridge weight map and pre-computed segmentation logits. The
        case list is currently hard-coded (see the TODO below).

        Returns:
            A shuffled ``monai.data.DataLoader`` over a ``CacheDataset``.
        """
        # TODO: change back to this data list
        #with open(self.config['data_split_json'], "r") as f:
        #    split = json.load(f)
        #train_cases = split["train"]
        #
        train_cases = ["2290837.nii.gz"]

        data_list = []
        for case_name in train_cases:
            # TODO in the json the pred_seg_logits needs to be changed!
            case_entry = {
                "image": join(self.config['vol_data_path'], case_name),
                "gt": join(self.config['label_data_path'], case_name),
                "bridge_weight_map": join(self.config['bridge_weight_map_path'], case_name),
                "pred_seg_logits": join(self.config['pred_seg_logits'], case_name),
            }
            data_list.append(case_entry)

            if self.config['debug']:
                # Debug mode: repeat the first case 30 more times (31 entries
                # in total) and ignore every other case.
                data_list.extend([case_entry] * 30)
                print(f"training using case: {data_list[0]}")
                break

        print(f"Train cases: {len(train_cases)}")
        print(f"Some examples:")
        print(train_cases[:5])

        # Keys read from disk, and the same set plus the derived ROI mask.
        disk_keys = ["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']
        tensor_keys = ["image", 'gt', 'roi_mask', 'bridge_weight_map', 'pred_seg_logits']

        transforms = Compose([
            # Load image
            LoadImaged(keys=disk_keys),
            EnsureChannelFirstd(keys=disk_keys),
            # Normalize uint8 input
            ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
            # Create a ROI mask for cropping
            GetROIMaskdd(keys=["gt"], ignore_mask_value=2, new_key_names=["roi_mask"]),
            # Crop or pad every volume to the configured patch size
            ResizeWithPadOrCropd(keys=tensor_keys, spatial_size=self.config['patch_size'], mode="minimum"),
            GetBinaryLabeld(keys=["gt"], ignore_mask_value=2),
            EnsureTyped(keys=tensor_keys, track_meta=False),
        ])

        print("Initializing Dataset...")
        train_ds = CacheDataset(
            data=data_list,
            transform=transforms,
            cache_rate=self.config['train_cache_rate'],
            num_workers=self.config['num_workers'],
            progress=True,
        )

        print("Initializing Train DataLoader...")
        return monai.data.DataLoader(
            train_ds,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            shuffle=True,
            pin_memory=True,
        )
    
    def _set_val_dataloader(self):
        """Build the MONAI validation DataLoader (batch size 1).

        The validation cases come from the ``"val"`` entry of the split JSON.
        In debug mode a single hard-coded training case is used instead, but
        the summary printed afterwards still refers to the JSON's val cases.

        Returns:
            A ``monai.data.DataLoader`` over a ``CacheDataset``.
        """
        with open(self.config['data_split_json'], "r") as f:
            split = json.load(f)
        val_cases = split["val"]

        debug_mode = self.config['debug']
        if debug_mode:
            print(f"Debug mode: Using training cases for validation dataloader")
            # train_cases = split["train"] # TODO uncomment
            train_cases = ["2290837.nii.gz"] # TODO remove
            # Only the first training sample is validated in debug mode.
            selected_cases = train_cases[:1]
        else:
            selected_cases = val_cases

        data_list = []
        for case_name in selected_cases:
            data_list.append({
                "image": join(self.config['vol_data_path'], case_name),
                "gt": join(self.config['label_data_path'], case_name),
                "bridge_weight_map": join(self.config['bridge_weight_map_path'], case_name),
                "pred_seg_logits": join(self.config['pred_seg_logits'], case_name),
            })
            if debug_mode:
                print(f"Validation using case: {case_name}")

        print(f"Val cases: {len(val_cases)}")
        print(f"Some examples:")
        print(val_cases[:5])

        # Keys read from disk, and the same set plus the derived ROI mask.
        disk_keys = ["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']
        tensor_keys = ["image", "gt", "roi_mask", 'bridge_weight_map', 'pred_seg_logits']

        transforms = Compose([
            # Load image
            LoadImaged(keys=disk_keys),
            EnsureChannelFirstd(keys=disk_keys),
            # Normalize uint8 input
            ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
            # Create a ROI mask for cropping
            GetROIMaskdd(keys=["gt"], ignore_mask_value=2, new_key_names=["roi_mask"]),
            # Crop or pad every volume to the configured patch size
            ResizeWithPadOrCropd(keys=tensor_keys, spatial_size=self.config['patch_size'], mode="minimum"),
            GetBinaryLabeld(keys=["gt"], ignore_mask_value=2),
            EnsureTyped(keys=tensor_keys, track_meta=False),
        ])

        print("Initializing Dataset...")
        val_ds = CacheDataset(
            data=data_list,
            transform=transforms,
            cache_rate=self.config['val_cache_rate'],
            num_workers=self.config['num_workers'],
            progress=True,
        )

        print("Initializing Val DataLoader...")
        return monai.data.DataLoader(val_ds, batch_size=1, num_workers=self.config['num_workers'])

    def _set_optimizers(self):
        """Return an AdamW optimizer over the generator's parameters.

        The learning rate is taken from ``config['learning_rate']``.
        """
        generator_params = self.G_model.parameters()
        return optim.AdamW(generator_params, lr=self.config['learning_rate'])
    
    def _set_scheduler(self):
        """Return the cosine-annealing LR scheduler for the generator optimizer.

        The schedule spans ``config['num_epochs']`` epochs and decays to one
        tenth of ``config['learning_rate']``. When ``config['resume']`` is set,
        the schedule is positioned at ``config['resume_epoch']`` (default 0).

        Returns:
            A ``CosineAnnealingLR`` bound to ``self.opt_G``.
        """
        # If resuming, last_epoch should be start_epoch - 1
        if self.config.get('resume'):
            last_epoch = self.config.get('resume_epoch', 0) - 1
        else:
            last_epoch = -1

        if last_epoch != -1:
            # PyTorch schedulers raise KeyError for last_epoch >= 0 unless
            # every param group carries 'initial_lr'. The optimizer here is
            # freshly built, so seed it from the group's current lr, as the
            # scheduler itself does when last_epoch == -1.
            for group in self.opt_G.param_groups:
                group.setdefault('initial_lr', group['lr'])

        G_cosAnnealLR = CosineAnnealingLR(
            self.opt_G,
            self.config['num_epochs'],
            eta_min=self.config['learning_rate'] / 10,
            last_epoch=last_epoch,
        )
        return G_cosAnnealLR
    
    def saving_logic(self, best_val_value, val_avg_value, epoch):
        """Save the best-so-far model and a periodic checkpoint.

        Args:
            best_val_value: Best validation value seen before this epoch.
            val_avg_value: Validation value of the current epoch.
            epoch: Current epoch index.

        Returns:
            The updated best validation value.
        """

        def _write_checkpoint(file_name):
            # Store weights, optimizer state and bookkeeping under file_name.
            target_path = join(self.model_save_path, file_name)
            torch.save({
                    'epoch': epoch,
                    'model_weights': self.G_model.state_dict(),
                    'optimizer_state_dict': self.opt_G.state_dict(),
                    'val_value': val_avg_value,
                }, target_path)
            print(f"Saved checkpoint: {target_path}")

        # A strictly better validation value replaces the best model.
        if best_val_value < val_avg_value:
            best_val_value = val_avg_value
            _write_checkpoint("model_best.pth")

        # Periodic checkpoint every 10th epoch, whatever the score.
        if epoch % 10 == 0:
            _write_checkpoint(f"model_epoch_{epoch}.pth")

        return best_val_value

    def train_epoch(self, **kwargs):
        """Train the generator for one epoch.

        Args:
            **kwargs: ``epoch`` is read and used for the progress bar, the
                periodic prediction dump and the summary print.

        Returns:
            ``(G_train_avg_loss, G_per_criterio_loss)``: the total loss
            averaged over the batches, and a dict with each individual
            criterion averaged the same way.
        """
        epoch = kwargs.get('epoch')
        device = self.config['device']
        # autocast needs the bare device type ('cuda' / 'cpu'); an indexed
        # string such as 'cuda:0' is valid for .to() but rejected by autocast.
        amp_device_type = torch.device(device).type

        G_epoch_loss = 0

        G_per_criterio_loss = {} # voxel wise metrics dict

        self.G_model.train()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.config['num_epochs']}")

        for batch_dict in pbar:
            # batch_dict contains:
                # image (volume normalized)
                # gt (the real ground truth binary)
                # the roi_mask (with the region to ignore 0 and to consider 1)
                # weighted map for the bridge weight loss
                # segmentation logits (pre-computed to save time!)
            input_image = batch_dict['image'].to(device)
            ground_truth = batch_dict['gt'].to(device)
            roi_mask = batch_dict['roi_mask'].to(device)
            bridge_weight_map = batch_dict['bridge_weight_map'].to(device)
            pred_seg_logits = batch_dict['pred_seg_logits'].to(device)

            self.opt_G.zero_grad()

            with autocast(device_type=amp_device_type):
                input_G = torch.cat([input_image, pred_seg_logits], dim=1)
                refined_logits = self.G_model(input_G)
                # Only used for the periodic prediction dump after the loop.
                refined_probs_for_g = torch.sigmoid(refined_logits)

                # Segmentation Loss (Ground Truth Accuracy)
                # Compare refined_logits directly to the ground truth
                voxel_wise_loss, G_losses_dict = self.G_voxel_criterion([refined_logits], [ground_truth], roi_mask=[roi_mask], bridge_weight_map=bridge_weight_map)

                # Only the voxel-wise loss is optimised in this class.
                loss_g_total = voxel_wise_loss

            # Backward G
            self.G_scaler.scale(loss_g_total).backward()
            self.G_scaler.step(self.opt_G)

            # Update Scaler once per batch
            self.G_scaler.update()

            ##### Handle loss graphs #####
            # Overall loss
            G_epoch_loss += loss_g_total.item()

            # adding all individual metrics to the G dict
            # NOTE(review): if the criterion returns tensors that require
            # grad, summing them here keeps their graphs alive - confirm
            # they are plain floats or detached.
            for G_criterio_name, criterio_value in G_losses_dict.items():
                G_per_criterio_loss[G_criterio_name] = G_per_criterio_loss.get(G_criterio_name, 0) + criterio_value

            # Update status bar
            pbar.set_postfix({
                "G_Loss": loss_g_total.item()
            })

        if epoch%10 == 0:
            # Save a prediction (tensors of the last batch of the epoch)
            self.save_vol(refined_probs_for_g, join(self.preds_path, f"epoch_{epoch}_pred_train.nii.gz"))
            self.save_vol(input_image, join(self.preds_path, f"epoch_{epoch}_input_train.nii.gz"))
            self.save_vol(pred_seg_logits, join(self.preds_path, f"epoch_{epoch}_input_pred_seg_logits_train.nii.gz"))
            self.save_vol(ground_truth, join(self.preds_path, f"epoch_{epoch}_gt_train.nii.gz"))

        num_batches = len(self.train_loader)
        G_train_avg_loss = G_epoch_loss / num_batches
        print(f"Epoch {epoch} Finished. Avg Loss: {G_train_avg_loss:.6f}")

        # Replace each accumulated sum with its mean. Iterating the
        # accumulator covers every key seen during the epoch, not only the
        # keys of the last batch.
        for criterio_name in G_per_criterio_loss:
            G_per_criterio_loss[criterio_name] = G_per_criterio_loss[criterio_name] / num_batches

        return G_train_avg_loss, G_per_criterio_loss
    
    def val(self, **kwargs):
        """Evaluate the generator on the validation loader.

        Args:
            **kwargs: ``epoch`` is read and used for the progress bar, the
                periodic prediction dump and the summary print.

        Returns:
            ``(val_avg_value, val_avg_pixel_loss, per_criterio_val_loss)``:
            the validation metric and the total loss, both averaged over the
            batches, and a dict of ``"val_<criterion>"`` averaged losses.
        """
        epoch = kwargs.get('epoch')
        device = self.config['device']
        self.G_model.eval()

        # General DSC validation value for quality control
        val_value_sum = 0
        epoch_val_pixel_loss = 0
        # Per-criterion val loss for checking overfitting; seeded from the
        # config so every configured criterion appears in the result.
        per_criterio_val_loss = {
            f"val_{val_criterio_name}": 0
            for val_criterio_name in self.config['criterion']
        }

        pbar = tqdm(self.val_loader, desc=f"Val epoch {epoch}/{self.config['num_epochs']}")
        for batch_dict in pbar:
            input_image = batch_dict['image'].to(device)
            ground_truth = batch_dict['gt'].to(device)
            # Mask of the region over which the loss is computed
            roi_mask = batch_dict['roi_mask'].to(device)
            bridge_weight_map = batch_dict['bridge_weight_map'].to(device)
            pred_seg_logits = batch_dict['pred_seg_logits'].to(device)

            with torch.no_grad():
                input_G = torch.cat([input_image, pred_seg_logits], dim=1)
                refined_logits = self.G_model(input_G)
                # Calculate DSC (Compare Prediction vs. GT)
                val_value = self.val_metric(pred=refined_logits, target=ground_truth, roi_mask=roi_mask)
                # Also compute val losses for logging (no deep supervision here)
                val_loss, val_losses_dict = self.G_voxel_criterion([refined_logits], [ground_truth], roi_mask=[roi_mask], bridge_weight_map=bridge_weight_map, deep_supervision_weights=[1.0])

            val_value_sum += val_value # val metric (DSC)
            epoch_val_pixel_loss += val_loss.item() # all loss functions used for training
            for val_criterio_name, criterio_value in val_losses_dict.items():
                # .get() tolerates a criterion key that was not listed in
                # config['criterion'] instead of raising KeyError.
                log_key = f"val_{val_criterio_name}"
                per_criterio_val_loss[log_key] = per_criterio_val_loss.get(log_key, 0) + criterio_value
            pbar.set_postfix({"DSC": val_value})

        if epoch%10 == 0:
            # Dump the last validation batch: raw logits, binarised
            # prediction (threshold 0.5) and the corresponding inputs.
            pred_save = sigmoid(refined_logits)
            pred_save[pred_save>0.5] = 1.0
            pred_save[pred_save<=0.5] = 0.0
            self.save_vol(refined_logits, join(self.preds_path, f"epoch_{epoch}_logits_val.nii.gz"))
            self.save_vol(pred_save, join(self.preds_path, f"epoch_{epoch}_pred_val.nii.gz"))
            self.save_vol(input_image, join(self.preds_path, f"epoch_{epoch}_input_val.nii.gz"))
            self.save_vol(pred_seg_logits, join(self.preds_path, f"epoch_{epoch}_input_pred_seg_logits_val.nii.gz"))
            self.save_vol(ground_truth, join(self.preds_path, f"epoch_{epoch}_gt_val.nii.gz"))

        # computing mean of metrics
        num_batches = len(self.val_loader)
        val_avg_value = val_value_sum / num_batches
        val_avg_pixel_loss = epoch_val_pixel_loss / num_batches

        # Average every accumulated entry, not only the keys of the last batch.
        for log_key in per_criterio_val_loss:
            per_criterio_val_loss[log_key] = per_criterio_val_loss[log_key] / num_batches
        print(f"Epoch {epoch} with validation avg DSC: {val_avg_value:.6f} | avg Loss: {val_avg_pixel_loss:.6f}")
        return val_avg_value, val_avg_pixel_loss, per_criterio_val_loss
    
    def train_loop(self, **kwargs):
        """Run the full training schedule.

        For every epoch from ``self.start_epoch`` through
        ``config['num_epochs']`` (inclusive): train one epoch, validate, log
        the metrics to ``self.wandb_run``, save checkpoints through
        ``saving_logic`` and step the LR scheduler.
        """
        best_val_value = self.val_value

        # Make sure all weights are trainable
        for param in self.G_model.parameters():
            param.requires_grad = True

        for epoch in range(self.start_epoch, self.config['num_epochs'] + 1):
            # Keep the current epoch available as an attribute.
            self.epoch = epoch

            # Train one epoch
            G_train_avg_loss, G_per_criterio_loss = self.train_epoch(
                epoch=self.epoch
            )
            # Perform evaluation
            val_avg_value, val_avg_pixel_loss, per_criterio_val_loss = self.val(
                epoch=self.epoch
            )

            # Save in wandb
            log_train_data = {
                    "epoch_train": self.epoch,
                    "train_avg_loss": G_train_avg_loss,
                    "val_Dice": val_avg_value,
                    "val_avg_pixel_loss": val_avg_pixel_loss,
                    "G_lr": self.opt_G.param_groups[0]['lr'],
                }

            # This class never creates a discriminator optimizer, so only log
            # its learning rate when one exists instead of raising
            # AttributeError on the first epoch.
            opt_d = getattr(self, 'opt_D', None)
            if opt_d is not None:
                log_train_data["D_lr"] = opt_d.param_groups[0]['lr']

            for criterio_name in G_per_criterio_loss.keys():
                log_train_data[criterio_name] = G_per_criterio_loss[criterio_name]
                if criterio_name.endswith("_fullres"):
                    print(f"_fullres is still in the loss function! It should not!")

            for val_criterio_name in per_criterio_val_loss.keys():
                log_train_data[val_criterio_name] = per_criterio_val_loss[val_criterio_name]

            self.wandb_run.log(
                log_train_data
            )

            # Checking if saving
            best_val_value = self.saving_logic(
                best_val_value=best_val_value,
                val_avg_value=val_avg_value,
                epoch=self.epoch
            )

            # Applying learning rate Cosine Annealing
            self.G_cosAnnealLR.step()

In [None]:
# Load the training configuration and launch the training loop.
CONFIG_FILE = '../configs/post_process_GANs.json'
with open(CONFIG_FILE, "r") as f:
    config_content = json.load(f)
# The trainer defined in this notebook is postprocessConvNeXt; no class named
# postprocessGANs is defined here.
GANs_train_object = postprocessConvNeXt(config_content)
GANs_train_object.train_loop()

In [None]:
# TODO check if the D_loss is within the value expected