In [None]:
"""
main.py - Main script for training Noise2Void with NAFNet
"""

import os
import torch
from torch.utils.data import DataLoader

from N2V_Vanilla.config import create_config, TrainingConfig
from N2V_Vanilla.models import create_nafnet
from N2V_Vanilla.losses import create_loss_function
from N2V_Vanilla.dataset import NoisyPatchDataset, ValidationDataset
from N2V_Vanilla.trainer import create_trainer
from N2V_Vanilla.utils import setup_reproducibility, print_model_info, create_directories

# ---------------------------
# === Main Training Function
# ---------------------------
def main(config: TrainingConfig):
    """Run a full Noise2Void/NAFNet training session described by ``config``.

    Steps, in order: validate the configuration, set up reproducibility from
    ``config.seed``, create the output directories, build the datasets and
    dataloaders, the model, the loss, the optimizer and the LR scheduler, then
    hand everything to the trainer and start training.

    Args:
        config: Training configuration. A validation dataset/loader is only
            built when both ``val_gt_dir`` and ``val_noisy_dir`` are set.

    Raises:
        ValueError: If ``config.optimizer_type`` is neither "adam" nor
            "adamw" (case-insensitive).
    """
    # Validate configuration
    config.validate_config()

    # Setup reproducibility
    setup_reproducibility(config.seed)

    # Get device
    device = torch.device(config.get_device())
    print(f"Using device: {device}")

    # Create output directories
    create_directories(config.output_dir)

    # Create datasets
    print("\nCreating datasets...")
    train_ds = NoisyPatchDataset(
        root_dir=config.train_dir,
        patch_size=config.patch_size,
        dataset_size=config.dataset_size,
    )

    # Validation is optional: it needs both the ground-truth and noisy dirs.
    val_ds = None
    if config.val_gt_dir and config.val_noisy_dir:
        val_ds = ValidationDataset(
            gt_dir=config.val_gt_dir,
            noisy_dir=config.val_noisy_dir,
            normalize=False,  # Normalization handled in utils
        )

    # Create dataloaders
    train_loader = DataLoader(
        train_ds,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
    )

    # Truthiness check on purpose: ``None`` disables validation, and so does
    # an empty dataset if ValidationDataset defines ``__len__`` (presumably it
    # does -- verify against N2V_Vanilla.dataset).
    val_loader = DataLoader(
        val_ds,
        batch_size=1,  # Process one GT pair at a time
        shuffle=False,
        num_workers=0,
        pin_memory=config.pin_memory,
    ) if val_ds else None

    # Create model
    print("\nCreating NAFNet model...")
    model = create_nafnet(config.nafnet_config).to(device)
    print_model_info(model, config.nafnet_config)

    # Create loss function
    print(f"\nCreating {config.loss_type} loss function...")
    criterion = create_loss_function(
        loss_type=config.loss_type,
        device=device,
        **config.loss_weights
    )

    # Create optimizer and (optional) learning-rate scheduler
    optimizer = _build_optimizer(model, config)
    scheduler = _build_scheduler(optimizer, config)

    # Create trainer
    trainer = create_trainer(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        config=config.to_dict(),
    )

    # Optional: Resume from checkpoint
    # resume_path = "./training_output/checkpoints/checkpoint_epoch_XX.pth"
    # if os.path.exists(resume_path):
    #     trainer.resume_training(resume_path)

    # Start training
    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=config.epochs
    )


def _build_optimizer(model, config: TrainingConfig):
    """Return the Adam or AdamW optimizer selected by ``config.optimizer_type``.

    Raises:
        ValueError: If the optimizer type is not "adam" or "adamw".
    """
    optimizer_classes = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    optimizer_cls = optimizer_classes.get(config.optimizer_type.lower())
    if optimizer_cls is None:
        raise ValueError(f"Unknown optimizer type: {config.optimizer_type}")

    return optimizer_cls(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas,
    )


def _build_scheduler(optimizer, config: TrainingConfig):
    """Return the LR scheduler selected by ``config.scheduler_type``, or None.

    Recognised types (case-insensitive) are "step", "cosine" and "plateau".
    Any other value prints a warning and disables scheduling.
    """
    scheduler_type = config.scheduler_type.lower()

    if scheduler_type == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.step_size,
            gamma=config.gamma,
        )

    if scheduler_type == "cosine":
        # One cosine half-period spans a fifth of the run. Clamp to >= 1:
        # with fewer than 5 epochs the quotient truncates to 0 and
        # CosineAnnealingLR would divide by zero.
        t_max = max(1, int(config.epochs / 5))
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    if scheduler_type == "plateau":
        # mode='max': the monitored metric is expected to increase --
        # presumably validation PSNR; verify against the trainer's
        # scheduler.step(...) call.
        # ``verbose`` is intentionally not passed: it is deprecated in recent
        # PyTorch releases.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5,
        )

    print(f"Warning: Unknown scheduler type '{config.scheduler_type}', no scheduler will be used.")
    return None

if __name__ == "__main__":
    # Script entry point: build a configuration, override paths and
    # hyper-parameters for this experiment, then launch training.

    # Create configuration
    # Options: "default", "fast", "high_quality", "cpu"
    config = create_config("default")

    # Optional: Update paths if needed
    config.update_paths(
        train_dir=r"E:\PHD\phd_env\Proyectos\Denoising_challenge\Calcium\Data\Train\slices",
        val_gt_dir=r"E:\PHD\phd_env\Proyectos\Denoising_challenge\Calcium\Data\Val\GT\slices",
        val_noisy_dir=r"E:\PHD\phd_env\Proyectos\Denoising_challenge\Calcium\Data\Val\noisy\slices",
        output_dir="./training_output"
    )

    # Optional: Customize further
    config.epochs = 50
    config.batch_size = 32
    # NOTE(review): a run captured with this learning rate logged a train loss
    # of ~4.4e3 for one epoch -- 5e-3 may be too aggressive; consider lowering.
    config.learning_rate = 5e-3
    config.model_size = 'medium'
    # Presumably refreshes nafnet_config for the chosen size -- confirm in
    # N2V_Vanilla.config.
    config.update_model_config(model_size='medium')
    config.loss_type = "advanced"
    config.loss_weights['ssim_weight'] = 0.3
    config.num_workers = 8
    # Was `config.scheduler_type.lower() == "step"`: a comparison whose result
    # was discarded, so the scheduler type was never actually set here.
    config.scheduler_type = "step"

    main(config)

Using device: cuda

Creating datasets...
Found 6000 training images in E:\PHD\phd_env\Proyectos\Denoising_challenge\Calcium\Data\Train\slices
Found 1500 validation pairs

Creating NAFNet model...




MODEL INFORMATION
Architecture: NAFNet
Parameters: 1,128,161
Configuration:
  img_channel: 1
  width: 32
  middle_blk_num: 1
  enc_blk_nums: [1, 1, 1]
  dec_blk_nums: [1, 1, 1]

Creating advanced loss function...
Starting training for 50 epochs...

Epoch 1/50
----------------------------------------


Epoch 1 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:10<00:00,  1.07s/it, loss=0.034142]


Train Loss: 0.034142
Learning Rate: 5.00e-03


Epoch 1 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:32<00:00,  1.61s/it, psnr=32.709, ssim=0.733]
  plt.tight_layout()


Val PSNR: 34.6217, Val SSIM: 0.7835
New best model saved! PSNR: 34.6217

Epoch 2/50
----------------------------------------


Epoch 2 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:10<00:00,  1.07s/it, loss=0.033721]


Train Loss: 0.033721
Learning Rate: 5.00e-03


Epoch 2 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.703, ssim=0.724]


Val PSNR: 34.6703, Val SSIM: 0.7793
New best model saved! PSNR: 34.6703

Epoch 3/50
----------------------------------------


Epoch 3 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033807]


Train Loss: 0.033807
Learning Rate: 5.00e-03


Epoch 3 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.864, ssim=0.737]


Val PSNR: 34.8256, Val SSIM: 0.7891
New best model saved! PSNR: 34.8256

Epoch 4/50
----------------------------------------


Epoch 4 [Train]: 100%|█████████████████████████████████████████████| 625/625 [11:08<00:00,  1.07s/it, loss=4371.637140]


Train Loss: 4371.637140
Learning Rate: 5.00e-03


Epoch 4 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.602, ssim=0.746]


Val PSNR: 34.4736, Val SSIM: 0.7937

Epoch 5/50
----------------------------------------


Epoch 5 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:11<00:00,  1.07s/it, loss=0.033525]


Train Loss: 0.033525
Learning Rate: 5.00e-03


Epoch 5 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:32<00:00,  1.60s/it, psnr=33.531, ssim=0.757]


Val PSNR: 35.5294, Val SSIM: 0.8083
New best model saved! PSNR: 35.5294

Epoch 6/50
----------------------------------------


Epoch 6 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033695]


Train Loss: 0.033695
Learning Rate: 5.00e-03


Epoch 6 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.259, ssim=0.755]


Val PSNR: 35.2371, Val SSIM: 0.8054

Epoch 7/50
----------------------------------------


Epoch 7 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:11<00:00,  1.07s/it, loss=0.033535]


Train Loss: 0.033535
Learning Rate: 5.00e-03


Epoch 7 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.434, ssim=0.758]


Val PSNR: 35.4500, Val SSIM: 0.8094

Epoch 8/50
----------------------------------------


Epoch 8 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033548]


Train Loss: 0.033548
Learning Rate: 5.00e-03


Epoch 8 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.955, ssim=0.750]


Val PSNR: 34.8967, Val SSIM: 0.8002

Epoch 9/50
----------------------------------------


Epoch 9 [Train]: 100%|████████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033764]


Train Loss: 0.033764
Learning Rate: 5.00e-03


Epoch 9 [Val]: 100%|██████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.327, ssim=0.759]


Val PSNR: 35.3366, Val SSIM: 0.8094

Epoch 10/50
----------------------------------------


Epoch 10 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033524]


Train Loss: 0.033524
Learning Rate: 5.00e-03


Epoch 10 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.049, ssim=0.754]


Val PSNR: 34.9958, Val SSIM: 0.8031

Epoch 11/50
----------------------------------------


Epoch 11 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:08<00:00,  1.07s/it, loss=0.033566]


Train Loss: 0.033566
Learning Rate: 5.00e-03


Epoch 11 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.226, ssim=0.752]


Val PSNR: 35.2102, Val SSIM: 0.8033

Epoch 12/50
----------------------------------------


Epoch 12 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033917]


Train Loss: 0.033917
Learning Rate: 5.00e-03


Epoch 12 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.59s/it, psnr=33.077, ssim=0.751]


Val PSNR: 35.0478, Val SSIM: 0.8012

Epoch 13/50
----------------------------------------


Epoch 13 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033796]


Train Loss: 0.033796
Learning Rate: 5.00e-03


Epoch 13 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.324, ssim=0.741]


Val PSNR: 34.1937, Val SSIM: 0.7886

Epoch 14/50
----------------------------------------


Epoch 14 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033605]


Train Loss: 0.033605
Learning Rate: 5.00e-03


Epoch 14 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.456, ssim=0.760]


Val PSNR: 35.4633, Val SSIM: 0.8113

Epoch 15/50
----------------------------------------


Epoch 15 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:08<00:00,  1.07s/it, loss=0.033931]


Train Loss: 0.033931
Learning Rate: 2.50e-03


Epoch 15 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.572, ssim=0.737]


Val PSNR: 34.4912, Val SSIM: 0.7881

Epoch 16/50
----------------------------------------


Epoch 16 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033543]


Train Loss: 0.033543
Learning Rate: 2.50e-03


Epoch 16 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.211, ssim=0.750]


Val PSNR: 35.1993, Val SSIM: 0.8016

Epoch 17/50
----------------------------------------


Epoch 17 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:06<00:00,  1.07s/it, loss=0.033550]


Train Loss: 0.033550
Learning Rate: 2.50e-03


Epoch 17 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=32.973, ssim=0.750]


Val PSNR: 34.9455, Val SSIM: 0.8010

Epoch 18/50
----------------------------------------


Epoch 18 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033662]


Train Loss: 0.033662
Learning Rate: 2.50e-03


Epoch 18 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.59s/it, psnr=32.899, ssim=0.742]


Val PSNR: 34.8334, Val SSIM: 0.7924

Epoch 19/50
----------------------------------------


Epoch 19 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:06<00:00,  1.07s/it, loss=0.033770]


Train Loss: 0.033770
Learning Rate: 2.50e-03


Epoch 19 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.59s/it, psnr=33.098, ssim=0.746]


Val PSNR: 35.0699, Val SSIM: 0.7973

Epoch 20/50
----------------------------------------


Epoch 20 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033684]


Train Loss: 0.033684
Learning Rate: 2.50e-03


Epoch 20 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.494, ssim=0.759]


Val PSNR: 35.4979, Val SSIM: 0.8099

Epoch 21/50
----------------------------------------


Epoch 21 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033735]


Train Loss: 0.033735
Learning Rate: 2.50e-03


Epoch 21 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.245, ssim=0.755]


Val PSNR: 35.2581, Val SSIM: 0.8063

Epoch 22/50
----------------------------------------


Epoch 22 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033539]


Train Loss: 0.033539
Learning Rate: 2.50e-03


Epoch 22 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.076, ssim=0.746]


Val PSNR: 35.0720, Val SSIM: 0.7984

Epoch 23/50
----------------------------------------


Epoch 23 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033591]


Train Loss: 0.033591
Learning Rate: 2.50e-03


Epoch 23 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.306, ssim=0.748]


Val PSNR: 35.3077, Val SSIM: 0.8007

Epoch 24/50
----------------------------------------


Epoch 24 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033865]


Train Loss: 0.033865
Learning Rate: 2.50e-03


Epoch 24 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.436, ssim=0.753]


Val PSNR: 35.4432, Val SSIM: 0.8054

Epoch 25/50
----------------------------------------


Epoch 25 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033729]


Train Loss: 0.033729
Learning Rate: 2.50e-03


Epoch 25 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.293, ssim=0.752]


Val PSNR: 35.3146, Val SSIM: 0.8045

Epoch 26/50
----------------------------------------


Epoch 26 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:06<00:00,  1.07s/it, loss=0.033857]


Train Loss: 0.033857
Learning Rate: 2.50e-03


Epoch 26 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.232, ssim=0.754]


Val PSNR: 35.2281, Val SSIM: 0.8049

Epoch 27/50
----------------------------------------


Epoch 27 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033702]


Train Loss: 0.033702
Learning Rate: 2.50e-03


Epoch 27 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.156, ssim=0.749]


Val PSNR: 35.1577, Val SSIM: 0.8010

Epoch 28/50
----------------------------------------


Epoch 28 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:07<00:00,  1.07s/it, loss=0.033674]


Train Loss: 0.033674
Learning Rate: 2.50e-03


Epoch 28 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=32.927, ssim=0.740]


Val PSNR: 34.8885, Val SSIM: 0.7917

Epoch 29/50
----------------------------------------


Epoch 29 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033858]


Train Loss: 0.033858
Learning Rate: 2.50e-03


Epoch 29 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.57s/it, psnr=33.017, ssim=0.743]


Val PSNR: 35.0044, Val SSIM: 0.7953

Epoch 30/50
----------------------------------------


Epoch 30 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033920]


Train Loss: 0.033920
Learning Rate: 1.25e-03


Epoch 30 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.58s/it, psnr=33.007, ssim=0.746]


Val PSNR: 35.0004, Val SSIM: 0.7982

Epoch 31/50
----------------------------------------


Epoch 31 [Train]: 100%|███████████████████████████████████████████████| 625/625 [11:09<00:00,  1.07s/it, loss=0.033885]


Train Loss: 0.033885
Learning Rate: 1.25e-03


Epoch 31 [Val]: 100%|█████████████████████████████████████████| 20/20 [00:31<00:00,  1.59s/it, psnr=33.007, ssim=0.750]


Val PSNR: 34.9752, Val SSIM: 0.8004

Epoch 32/50
----------------------------------------


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000028126484360>  1.09s/it, loss=0.033948]
Traceback (most recent call last):
  File "C:\Users\Guill\miniconda3\envs\phd_env_conda\Lib\site-packages\torch\utils\data\dataloader.py", line 1650, in __del__
    self._shutdown_workers()
  File "C:\Users\Guill\miniconda3\envs\phd_env_conda\Lib\site-packages\torch\utils\data\dataloader.py", line 1614, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "C:\Users\Guill\miniconda3\envs\phd_env_conda\Lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Guill\miniconda3\envs\phd_env_conda\Lib\multiprocessing\popen_spawn_win32.py", line 112, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt: 
Epoch 32 [Train]:  74%|██████████████████████████████████▋      