# Develop a light GAN for post processing
* The input must be the biggest size in the dataset
* Both Generator and Discriminator must fit in the GPU (48GB VRAM)


### Check biggest (384,384,384 although there is only 1 case with this shape...)

In [None]:
# Find the bigest shape
import os
import tifffile as tiff
import numpy as np
import nibabel as nib
all_shapes = []
root_dataset = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop"
for file_name in os.listdir(root_dataset):
    if file_name.endswith('.nii.gz'):
        file_path = os.path.join(root_dataset, file_name)
        img_array = nib.load(file_path).get_fdata()
        all_shapes.append(img_array.shape)
        #if img_array.shape[0]>320:
        print(img_array.shape)

In [None]:
x_all_shapes = [shape_element[0] for shape_element in all_shapes]
y_all_shapes = [shape_element[1] for shape_element in all_shapes]
z_all_shapes = [shape_element[2] for shape_element in all_shapes]

print(f"Biggest x: {max(x_all_shapes)}")
print(f"Biggest y: {max(y_all_shapes)}")
print(f"Biggest z: {max(z_all_shapes)}")

### Create MedNeXt network

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint # Crucial for 48GB VRAM optimization

# --- 1. Modern Components ---

class LayerNorm3d(nn.Module):
    """ Modern LayerNorm that supports Channels First (N, C, D, H, W) """
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None, None] * x + self.bias[:, None, None, None]

class GRN(nn.Module):
    """ 
    Global Response Normalization (The V2 Secret Weapon)
    Prevents "dead channels" in sparse data by normalizing feature competition.
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, 1))
        self.eps = eps

    def forward(self, x):
        # Input: [N, D, H, W, C] (Channels Last for efficiency)
        Gx = torch.norm(x, p=2, dim=(1,2,3), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * Nx) + self.beta + x

class DropPath(nn.Module):
    """ Stochastic Depth: Randomly drops paths to prevent overfitting in deep models """
    def __init__(self, drop_prob=0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_() 
        output = x.div(keep_prob) * random_tensor
        return output

# --- 2. The Block ---

class ConvNeXtV2Block(nn.Module):
    def __init__(self, dim, dilation=1, drop_path=0.0):
        super().__init__()
        
        # A. Depthwise Conv (Spatial Context)
        # padding=dilation ensures we maintain 1:1 resolution (No shrinking!)
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=3, padding=dilation, 
                                groups=dim, dilation=dilation) 
        self.norm = LayerNorm3d(dim, eps=1e-6)
        
        # B. Pointwise MLP (Channel Reasoning)
        # Inverted Bottleneck: Expand 1 -> 4 -> 1
        self.pwconv1 = nn.Linear(dim, 4 * dim) 
        self.act = nn.GELU()
        self.grn = GRN(4 * dim) # The V2 upgrade
        self.pwconv2 = nn.Linear(4 * dim, dim)
        
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        
        # 1. Spatial Processing
        x = self.dwconv(x)
        x = self.norm(x)
        
        # 2. Permute to Channels Last (N, D, H, W, C) for Linear Layers
        x = x.permute(0, 2, 3, 4, 1) 
        
        # 3. MLP with GRN
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        
        # 4. Restore Layout
        x = x.permute(0, 4, 1, 2, 3)

        # 5. Residual Connection
        x = input + self.drop_path(x)
        return x

# --- 3. The Network ---

class RefinerNetwork(nn.Module):
    def __init__(self, in_channels=2, base_dim=96, depth=18, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        
        # 1. Stem: Project [Scan + Mask] into Feature Space
        self.stem = nn.Sequential(
            nn.Conv3d(in_channels, base_dim, kernel_size=3, padding=1),
            LayerNorm3d(base_dim)
        )
        
        # 2. Deep Body with "Sawtooth" Dilation
        # This pattern expands the receptive field to >60 pixels to see across gaps
        # Pattern: 1, 2, 4, 8, 4, 2 ...
        base_dilations = [1, 2, 4, 8, 4, 2]
        self.blocks = nn.ModuleList()
        
        # Stochastic depth decay (linear)
        dp_rates = [x.item() for x in torch.linspace(0, 0.2, depth)] 
        
        for i in range(depth):
            d = base_dilations[i % len(base_dilations)]
            self.blocks.append(ConvNeXtV2Block(base_dim, dilation=d, drop_path=dp_rates[i]))
            
        # 3. Head: Project back to 1 channel mask
        self.head = nn.Conv3d(base_dim, 1, kernel_size=1)

    def forward(self, x):
        # x shape: [Batch, 2, D, H, W]
        x = self.stem(x)
        
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Trades compute for VRAM, allowing Depth=18 on 48GB GPU
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
                
        return self.head(x)

# --- 4. Instantiation ---
def get_refiner():
    # in_channels=2: Channel 0 is Original Scan, Channel 1 is Broken Prediction
    # proposed refinerNetwork = RefinerNetwork(in_channels=2, base_dim=96, depth=18, use_checkpoint=True)
    refinerNetwork = RefinerNetwork(in_channels=2, base_dim=16, depth=18, use_checkpoint=True)
    return refinerNetwork

In [2]:
def profile_model(model, input_shape=(1, 2, 128, 128, 128), device='cuda', train_mode=True):
    """
    Prints model summary, measures VRAM usage and Speed.
    
    Args:
        model: PyTorch model
        input_shape: Tuple of input shape (N, C, D, H, W)
        device: 'cuda' or 'cpu'
        train_mode: If True, measures backward pass memory/time too.
    """
    print("="*60)
    print(f"Running Benchmark on: {torch.cuda.get_device_name(0)}")
    print(f"Config: {input_shape} | Training Mode: {train_mode}")
    print("="*60)

    # 1. Print Model Structure (using torchinfo if available)
    try:
        from torchinfo import summary
        print(summary(model, input_size=input_shape, depth=3, 
                      col_names=["input_size", "output_size", "num_params", "kernel_size"]))
    except ImportError:
        print("[Info] Install 'torchinfo' for a prettier model summary.")
        print(f"Total Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # 2. Setup
    model.to(device)
    if train_mode:
        model.train()
    else:
        model.eval()

    # Create dummy input
    dummy_input = torch.randn(input_shape).to(device)
    dummy_target = torch.randn((input_shape[0], 1, input_shape[2], input_shape[3], input_shape[4])).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 3. Warmup (Stabilize GPU clock speeds)
    print("\nWarming up GPU...")
    for _ in range(5):
        optimizer.zero_grad()
        out = model(dummy_input)
        if train_mode:
            loss = criterion(out, dummy_target)
            loss.backward()
            optimizer.step()
    torch.cuda.synchronize()

    # 4. Measure VRAM (Peak Allocation)
    print("Measuring VRAM usage...")
    torch.cuda.reset_peak_memory_stats()
    optimizer.zero_grad()
    
    # Run one step
    start_vram = torch.cuda.memory_allocated()
    out = model(dummy_input)
    if train_mode:
        loss = criterion(out, dummy_target)
        loss.backward()
    
    peak_memory = torch.cuda.max_memory_allocated()
    memory_gb = peak_memory / (1024 ** 3)
    
    print(f"Peak VRAM Usage: {memory_gb:.2f} GB")

    # 5. Measure Speed (Average over 50 steps)
    print("Measuring Speed (50 steps)...")
    timings = []
    scaler = torch.cuda.amp.GradScaler() # Optional: simulate AMP overhead if needed
    
    # Using CUDA Events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    for _ in range(50):
        optimizer.zero_grad()
        start_event.record()
        
        out = model(dummy_input)
        if train_mode:
            loss = criterion(out, dummy_target)
            loss.backward()
            optimizer.step()
            
        end_event.record()
        torch.cuda.synchronize()
        timings.append(start_event.elapsed_time(end_event)) # Returns milliseconds

    avg_time_ms = sum(timings) / len(timings)
    print(f"Average Time per Step: {avg_time_ms:.2f} ms")
    print("="*60)

# --- 3. EXECUTION ---

if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. This benchmark requires a GPU.")
    else:
        # Instantiate your specific configuration
        # NOTE: use_checkpoint=True saves memory but adds ~30% time overhead
        model = RefinerNetwork(in_channels=2, base_dim=16, depth=18, use_checkpoint=True)
        
        # Test with a standard Vesuvius crop size
        # Change batch_size (1st dim) to see how much your 48GB card can handle!
        profile_model(model, input_shape=(1, 2, 128, 128, 128), train_mode=True)

Running Benchmark on: NVIDIA GeForce RTX 4060 Ti
Config: (1, 2, 128, 128, 128) | Training Mode: True
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
RefinerNetwork                           [1, 2, 128, 128, 128]     [1, 1, 128, 128, 128]     --                        --
├─Sequential: 1-1                        [1, 2, 128, 128, 128]     [1, 16, 128, 128, 128]    --                        --
│    └─Conv3d: 2-1                       [1, 2, 128, 128, 128]     [1, 16, 128, 128, 128]    880                       [3, 3, 3]
│    └─LayerNorm3d: 2-2                  [1, 16, 128, 128, 128]    [1, 16, 128, 128, 128]    32                        --
├─ModuleList: 1-2                        --                        --                        --                        --
│    └─ConvNeXtV2Block: 2-3              [1, 16, 128, 128, 128]    [1, 16, 128, 128, 128]    --                        --
│    │    └─Conv3d: 3-1     

  scaler = torch.cuda.amp.GradScaler() # Optional: simulate AMP overhead if needed


Average Time per Step: 4071.90 ms


### Building the training class

In [2]:

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
import torch
import json
import sys
from os.path import join
sys.path.append("../utils")
from main_train_class import main_train_STU_Net
from tqdm import tqdm
from torch.nn.functional import sigmoid, binary_cross_entropy_with_logits
# Standard Library Imports
from os.path import join
import sys

import json

# Third-Party Library Imports
import torch
import torch.optim as optim
from torch.nn.functional import sigmoid
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from torch.nn import MSELoss
# MONAI Specific Imports
import monai
from monai.data import CacheDataset

from monai.transforms import (
    Compose,
    CopyItemsd,
    LoadImaged, 
    ScaleIntensityRanged, 
    ResizeWithPadOrCropd,
    EnsureTyped,
    EnsureChannelFirstd
)
from mask_utils import GetROIMaskdd, GetBinaryLabeld
# Local Project Imports

class postprocessConvNeXt(main_train_STU_Net):
    def __init__(self, config):
        # TODO predict and save logits from the segmentation network
        self.config = config
        self.labda_seg = self.config['labda_seg']
        self.G_model = self._build_models()
        self.train_loader = self._set_train_dataloader() 
        self.val_loader = self._set_val_dataloader() 
        self.opt_G = self._set_optimizers()
        self.wandb_run = self._set_wandb_checkpoint() # Heritage
        # Expects lists of predictions # TODO put predictions inside of a list
        self.G_voxel_criterion = self._set_train_criterion() # Heritage
        
        self.val_metric = self._set_val_metric() # Heritage
        self.G_cosAnnealLR = self._set_scheduler()

        # set scaler for mix precision 
        self.G_scaler = GradScaler()
        
        # check resume
        self._resume()

    def _build_models(self):
        # TODO change network to base_dim=96 depth=18
        model = RefinerNetwork(in_channels=2, base_dim=4, depth=9, use_checkpoint=True).to(self.config['device'])
        
        return model
    
    def _resume(self):
        if self.config.get('resume'):
            G_checkpoint = torch.load(self.config['resume'], map_location="cpu", weights_only=False) 
            G_model_weights = G_checkpoint['model_weights']  
            self.G_model.load_state_dict(G_model_weights, strict=True)
            self.G_model = self.G_model.to(self.config['device'])
            # optimizer load
            self.opt_G.load_state_dict(G_checkpoint['optimizer_state_dict'])
            # parameters
            self.start_epoch = G_checkpoint['epoch'] + 1 # To continue to the next epoch instead of repeating  
            self.val_value = G_checkpoint['val_value']
        else:
            self.start_epoch = 0
            self.val_value = 0

    def _set_train_dataloader(self):
        """ Getting the list of cases for training and loading using MONAI (all into memory)"""
        data_list = []
        # TODO: change back to this data list
        #with open(self.config['data_split_json'], "r") as f:
        #    split = json.load(f)
        #train_cases = split["train"]
        #
        train_cases = ["2290837.nii.gz"]
        for train_case in train_cases:
            complete_data_dict = {}
            complete_data_dict["image"] = join(self.config['vol_data_path'], train_case)
            complete_data_dict["gt"] = join(self.config['label_data_path'], train_case)
            complete_data_dict["bridge_weight_map"] = join(self.config['bridge_weight_map_path'], train_case)
            # TODO in the json the pred_seg_logits needs to be changed!
            complete_data_dict["pred_seg_logits"] = join(self.config['pred_seg_logits'], train_case)
            data_list.append(complete_data_dict)
            
            if self.config['debug']:
                for i in range(30):
                    data_list.append(complete_data_dict)
                print(f"training using case: {data_list[0]}")
                break  # repeat 30 cases for debug mode

        print(f"Train cases: {len(train_cases)}")
        print(f"Some examples:")
        print(train_cases[:5])

        transforms_list = [   
                # Load image 
                LoadImaged(keys=["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']),
                EnsureChannelFirstd(keys=["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']),

                # Normalize uint8 input
                ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),

                # Create a ROI mask for cropping 
                GetROIMaskdd(keys=["gt"], ignore_mask_value=2, new_key_names=["roi_mask"]),

                # Cropping or padding if bigger or smaller (expected to be all equall of smaller)
                ResizeWithPadOrCropd(keys=["image", 'gt', 'roi_mask', 'bridge_weight_map', 'pred_seg_logits'], spatial_size=self.config['patch_size'], mode="minimum"),
                GetBinaryLabeld(keys=["gt"], ignore_mask_value=2),
                EnsureTyped(keys=["image", "gt", "roi_mask", "bridge_weight_map", 'pred_seg_logits'], track_meta=False)
        ]

        transforms = Compose(transforms_list)
        
        print("Initializing Dataset...")
        train_ds = CacheDataset(
            data=data_list, 
            transform=transforms, 
            cache_rate=self.config['train_cache_rate'],  
            num_workers=self.config['num_workers'], 
            progress=True
        )

        print("Initializing Train DataLoader...")
        train_loader = monai.data.DataLoader(
            train_ds, 
            batch_size=self.config['batch_size'], 
            num_workers=self.config['num_workers'],
            shuffle=True,      
            pin_memory=True
        )
        return train_loader
    
    def _set_val_dataloader(self):
        data_list = []
        with open(self.config['data_split_json'], "r") as f:
            split = json.load(f)

        val_cases = split["val"]

        if self.config['debug']: 
            print(f"Debug mode: Using training cases for validation dataloader")
            # train_cases = split["train"] # TODO uncomment
            train_cases = ["2290837.nii.gz"] # TODO remove
            for train_case in train_cases:
                complete_data_dict = {}
                complete_data_dict["image"] = join(self.config['vol_data_path'], train_case)
                complete_data_dict["gt"] = join(self.config['label_data_path'], train_case)
                complete_data_dict["bridge_weight_map"] = join(self.config['bridge_weight_map_path'], train_case)
                complete_data_dict["pred_seg_logits"] = join(self.config['pred_seg_logits'], train_case)
                data_list.append(complete_data_dict)
                print(f"Validation using case: {train_case}")
                break  # The same training sample for validation in debug mode
        else:
            for val_case in val_cases:
                complete_data_dict = {}
                complete_data_dict["image"] = join(self.config['vol_data_path'], val_case)
                complete_data_dict["gt"] = join(self.config['label_data_path'], val_case)
                complete_data_dict["bridge_weight_map"] = join(self.config['bridge_weight_map_path'], val_case)
                complete_data_dict["pred_seg_logits"] = join(self.config['pred_seg_logits'], val_case)
                data_list.append(complete_data_dict)

        print(f"Val cases: {len(val_cases)}")
        print(f"Some examples:")
        print(val_cases[:5])

        transforms = Compose(
            [   
                # Load image 
                LoadImaged(keys=["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']),
                EnsureChannelFirstd(keys=["image", 'gt', 'bridge_weight_map', 'pred_seg_logits']),
                # Normalize uint8 input
                ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
                # Create a ROI mask for cropping 
                GetROIMaskdd(keys=["gt"], ignore_mask_value=2, new_key_names=["roi_mask"]),
                # Get random patches
                ResizeWithPadOrCropd(keys=["image", "gt", "roi_mask", 'bridge_weight_map', 'pred_seg_logits'], spatial_size=self.config['patch_size'], mode="minimum"),
                GetBinaryLabeld(keys=["gt"], ignore_mask_value=2),
                EnsureTyped(keys=["image", "gt", "roi_mask", 'bridge_weight_map', 'pred_seg_logits'], track_meta=False)
            ]
        )

        print("Initializing Dataset...")
        val_ds = CacheDataset(
            data=data_list, 
            transform=transforms, 
            cache_rate=self.config['val_cache_rate'],  
            num_workers=self.config['num_workers'], 
            progress=True
        )
        
        print("Initializing Val DataLoader...")
        val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=self.config['num_workers'])
        return val_loader

    def _set_optimizers(self):
        """Define the optimizer (e.g., Adam, SGD)."""
        opt_g = optim.AdamW(self.G_model.parameters(), lr=self.config['learning_rate'])
        return opt_g
    
    def _set_scheduler(self):
        """Define learning rate scheduler."""
        # If resuming, last_epoch should be start_epoch - 1
        last_epoch = self.config.get('resume_epoch', 0) - 1 if self.config.get('resume') else -1
        G_cosAnnealLR = CosineAnnealingLR(self.opt_G, self.config['num_epochs'], eta_min=self.config['learning_rate']/10, last_epoch=last_epoch)
        return G_cosAnnealLR
    
    def saving_logic(self, best_val_value, val_avg_value, epoch):
        """ Logic to save the best model and periodic checkpoints """

        if best_val_value < val_avg_value: 
            best_val_value = val_avg_value
            G_save_path = join(self.model_save_path, f"model_best.pth")
            torch.save({
                    'epoch': epoch,
                    'model_weights': self.G_model.state_dict(),  
                    'optimizer_state_dict': self.opt_G.state_dict(),
                    'val_value': val_avg_value,
                }, G_save_path)
            print(f"Saved checkpoint: {G_save_path}")
        
        # Save Checkpoint
        if epoch % 10 == 0: 
            G_save_path = join(self.model_save_path, f"model_epoch_{epoch}.pth")
            torch.save({
                    'epoch': epoch,
                    'model_weights': self.G_model.state_dict(),  
                    'optimizer_state_dict': self.opt_G.state_dict(),
                    'val_value': val_avg_value,
                }, G_save_path)
            print(f"Saved checkpoint: {G_save_path}")
        return best_val_value

    def train_epoch(self, **kwargs):
        """Logic for a single training epoch. Returns average loss."""
        epoch = kwargs.get('epoch')
        
        G_epoch_loss = 0

        G_per_criterio_loss = {} # voxel wise metrics dict

        self.G_model.train()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.config['num_epochs']}")

        for idx, batch_dict in enumerate(pbar):
            # batch_dict contains:
                # image (volume normalized)
                # gt (the real ground truth binary)
                # the roi_mask (with the region to ignore 0 and to consider 1)
                # weighted map for the bridge weight loss
                # segmentation logits (pre-computed to save time!)
            input_image = batch_dict['image'].to(self.config['device'])
            ground_truth = batch_dict['gt'].to(self.config['device'])
            roi_mask = batch_dict['roi_mask'].to(self.config['device'])
            bridge_weight_map = batch_dict['bridge_weight_map'].to(self.config['device'])
            pred_seg_logits = batch_dict['pred_seg_logits'].to(self.config['device']) # TODO change to pred_seg_logits

            self.opt_G.zero_grad()
            
            with autocast(device_type=self.config['device']):
                input_G = torch.cat([input_image, pred_seg_logits], dim=1)
                refined_logits = self.G_model(input_G)
                refined_probs_for_g = torch.sigmoid(refined_logits) 
                
                # Segmentation Loss (Ground Truth Accuracy)
                # Compare refined_logits directly to real_gt_mask
                voxel_wise_loss, G_losses_dict = self.G_voxel_criterion([refined_logits], [ground_truth], roi_mask=[roi_mask], bridge_weight_map=bridge_weight_map) 
                
                # loss_seg dominates (weight 100), loss_adv refines
                loss_g_total = voxel_wise_loss


            # Backward G
            self.G_scaler.scale(loss_g_total).backward()
            self.G_scaler.step(self.opt_G)
            
            # Update Scaler once per batch
            self.G_scaler.update()
            
            
            ##### Handle loss graphs #####
            # Overall loss
            G_epoch_loss += loss_g_total.item()
            
            # adding all individual metrics to the G dict
            for G_criterio_name in G_losses_dict.keys():
                if G_criterio_name in G_per_criterio_loss:
                    G_per_criterio_loss[G_criterio_name] += G_losses_dict[G_criterio_name]
                else:
                    G_per_criterio_loss[G_criterio_name] = G_losses_dict[G_criterio_name]
            
            # Update status bar
            pbar.set_postfix({
                "G_Loss": loss_g_total.item()
            })
        
        if epoch%10 == 0:
            # Save a prediction
            self.save_vol(refined_probs_for_g, join(self.preds_path, f"epoch_{epoch}_pred_train.nii.gz"))
            self.save_vol(input_image, join(self.preds_path, f"epoch_{epoch}_input_train.nii.gz"))
            self.save_vol(pred_seg_logits, join(self.preds_path, f"epoch_{epoch}_input_pred_seg_logits_train.nii.gz"))
            self.save_vol(ground_truth, join(self.preds_path, f"epoch_{epoch}_gt_train.nii.gz"))

        
        G_train_avg_loss = G_epoch_loss / len(self.train_loader)
        print(f"Epoch {epoch} Finished. Avg Loss: {G_train_avg_loss:.6f}")
        
        # This will replace each element in the dict with the mean
        for criterio_name in G_losses_dict.keys():
            G_per_criterio_loss[criterio_name] = G_per_criterio_loss[criterio_name] / len(self.train_loader)

        return G_train_avg_loss, G_per_criterio_loss
    
    def val(self, **kwargs):
        """Logic for evaluation. Returns a dictionary of metrics."""
        epoch = kwargs.get('epoch')
        self.G_model.eval()
        
        # General DSC validation value for quality controll
        val_value_sum = 0
        epoch_val_pixel_loss = 0
        adv_fake = 0
        adv_real = 0
        # Add the per criterio val loss for checking overfitting
        per_criterio_val_loss = {}
        for val_criterio_name in self.config['criterion']:
            per_criterio_val_loss[f"val_{val_criterio_name}"] = 0

        pbar = tqdm(self.val_loader, desc=f"Val epoch {epoch}/{self.config['num_epochs']}")
        for idx, batch_dict in enumerate(pbar):
            input_image = batch_dict['image'].to(self.config['device'])
            ground_truth = batch_dict['gt'].to(self.config['device']) 
            # Create the mask of the region to compute the loss
            roi_mask = batch_dict['roi_mask'].to(self.config['device'])
            bridge_weight_map = batch_dict['bridge_weight_map'].to(self.config['device'])
            pred_seg_logits = batch_dict['pred_seg_logits'].to(self.config['device'])

            with torch.no_grad():
                input_G = torch.cat([input_image, pred_seg_logits], dim=1)
                refined_logits = self.G_model(input_G)
                # Calculate DSC (Compare Prediction vs. GT)
                val_value = self.val_metric(pred=refined_logits, target=ground_truth, roi_mask=roi_mask)
                # Also compute val losses for logging (no deep supervision here)
                val_loss, val_losses_dict = self.G_voxel_criterion([refined_logits], [ground_truth], roi_mask=[roi_mask], bridge_weight_map=bridge_weight_map, deep_supervision_weights=[1.0]) 
                # commented to avoid overwhelming 
                #self.wandb_run.log({"val_value": val_value.item()})

            val_value_sum += val_value # val metric (DSC)
            epoch_val_pixel_loss += val_loss.item() # all losses function used for training (except Adv)
            for val_criterio_name in val_losses_dict.keys():
                per_criterio_val_loss[f"val_{val_criterio_name}"] += val_losses_dict[f"{val_criterio_name}"]
            pbar.set_postfix({"DSC": val_value})

        if epoch%10 == 0:
            pred_save = sigmoid(refined_logits)
            pred_save[pred_save>0.5] = 1.0
            pred_save[pred_save<=0.5] = 0.0
            self.save_vol(refined_logits, join(self.preds_path, f"epoch_{epoch}_logits_val.nii.gz"))
            self.save_vol(pred_save, join(self.preds_path, f"epoch_{epoch}_pred_val.nii.gz"))
            self.save_vol(input_image, join(self.preds_path, f"epoch_{epoch}_input_val.nii.gz"))
            self.save_vol(pred_seg_logits, join(self.preds_path, f"epoch_{epoch}_input_pred_seg_logits_val.nii.gz"))
            self.save_vol(ground_truth, join(self.preds_path, f"epoch_{epoch}_gt_val.nii.gz"))

        # computing mean of metrics
        val_avg_value = val_value_sum / len(self.val_loader)
        val_avg_pixel_loss = epoch_val_pixel_loss / len(self.val_loader)

        for val_criterio_name in val_losses_dict.keys():
            per_criterio_val_loss[f"val_{val_criterio_name}"] = per_criterio_val_loss[f"val_{val_criterio_name}"] / len(self.val_loader)
        print(f"Epoch {epoch} with validation avg DSC: {val_avg_value:.6f} | avg Loss: {val_avg_pixel_loss:.6f}")
        return val_avg_value, val_avg_pixel_loss, per_criterio_val_loss 
    
    def train_loop(self, **kwargs):
        """Standardized training loop."""
        best_val_value = self.val_value
        
        # Make sure all weights are trainable
        for param in self.G_model.parameters():
            param.requires_grad = True

        for self.epoch in range(self.start_epoch, self.config['num_epochs'] + 1):
            # Train one epoch
            G_train_avg_loss, G_per_criterio_loss = self.train_epoch(
                epoch=self.epoch
            )
            # Perform evaluation 
            val_avg_value, val_avg_pixel_loss, per_criterio_val_loss  = self.val(
                epoch=self.epoch
            )

            # Save in wandb
            log_train_data = {
                    "epoch_train": self.epoch,
                    "train_avg_loss": G_train_avg_loss,
                    "val_Dice": val_avg_value,
                    "val_avg_pixel_loss": val_avg_pixel_loss,
                    "G_lr": self.opt_G.param_groups[0]['lr'],
                    "D_lr": self.opt_D.param_groups[0]['lr']  
                }

            for criterio_name in G_per_criterio_loss.keys():
                log_train_data[criterio_name] = G_per_criterio_loss[criterio_name]
                if criterio_name.endswith("_fullres"):
                    print(f"_fullres is still in the loss function! It should not!")


            for val_criterio_name in per_criterio_val_loss.keys():
                log_train_data[val_criterio_name] = per_criterio_val_loss[val_criterio_name]

            self.wandb_run.log(
                log_train_data
            )

            # Checking if saving 
            best_val_value = self.saving_logic(
                best_val_value=best_val_value, 
                val_avg_value=val_avg_value, 
                epoch=self.epoch
            )

            # Applying learning rate Cosine Annealing
            self.G_cosAnnealLR.step() 

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [3]:
CONFIG_FILE = '../configs/post_process_ConvNeXt.json'
with open(CONFIG_FILE, "r") as f:
    config_content = json.load(f)
GANs_train_object = postprocessConvNeXt(config_content)
GANs_train_object.train_loop()

training using case: {'image': '/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/small_train/images/2290837.nii.gz', 'gt': '/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/small_train/labels/2290837.nii.gz', 'bridge_weight_map': '/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/small_train/train_bridge_weight_map/2290837.nii.gz', 'pred_seg_logits': '/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/small_train/pred_seg_logits/2290837.nii.gz'}
Train cases: 1
Some examples:
['2290837.nii.gz']
Initializing Dataset...
Initializing Train DataLoader...
Debug mode: Using training cases for validation dataloader
Validation using case: 2290837.nii.gz
Val cases: 79
Some examples:
['693501383.nii.gz', '118041886.nii.gz', '571334887.nii.gz', '2536049117.nii.gz', '1127903126.nii.gz']
Initializing Dataset...


Loading dataset: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]


Initializing Val DataLoader...


[34m[1mwandb[0m: Currently logged in as: [33mshadowtwin[0m ([33mfaking_it[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0/30:   0%|          | 0/31 [00:05<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.95 GiB. GPU 0 has a total capacity of 15.57 GiB of which 1.37 GiB is free. Including non-PyTorch memory, this process has 13.63 GiB memory in use. Of the allocated memory 13.37 GiB is allocated by PyTorch, and 51.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# TODO check if the D_loss is within the value expected