In [1]:
import sys
import os

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
os.chdir("..")

In [2]:
import torch
import matplotlib.pyplot as plt
import climex_utils as cu
import train_prob_unet_model as tm  
from prob_unet import ProbabilisticUNet
# from prob_unet_utils import plot_losses, plot_losses_mae
import pickle
import numpy as np
import random
import os

In [3]:
# Training model with afcrps
import torch
import matplotlib.pyplot as plt
import climex_utils as cu
import train_prob_unet_model as tm  
from prob_unet import ProbabilisticUNet
# from prob_unet_utils import plot_losses, plot_losses_mae
import pickle
import numpy as np
import random
import os

if __name__ == "__main__":

    def set_seed(seed):
        random.seed(seed) 
        np.random.seed(seed)  
        torch.manual_seed(seed) 
        torch.cuda.manual_seed(seed)  
        torch.cuda.manual_seed_all(seed)  
        torch.backends.cudnn.deterministic = True  
        torch.backends.cudnn.benchmark = False  
        os.environ['PYTHONHASHSEED'] = str(seed)

    # -- 1) Set seed for reproducibility
    set_seed(42)  

    # -- 2) Importing all required arguments
    args = tm.get_args()
    args.lowres_scale = 16
    args.batch_size = 32
    args.num_epochs = 30
    args.years_train = range(1960, 2020)
    args.years_val = range(2021, 2030)

    # Initialize the Probabilistic UNet model
    probunet_model = ProbabilisticUNet(
        input_channels=len(args.variables),
        num_classes=len(args.variables),
        latent_dim=16,
        num_filters=[32, 64, 128, 256],
        model_channels=32,
        channel_mult=[1, 2, 4, 8],
        beta_0=0.0,
        beta_1=0.0,
        beta_2=0.0  
    ).to(args.device)

    # -- 3) Prepare datasets
    dataset_train = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_train,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )
    dataset_val = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_val,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )
    dataset_test = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_test,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )

    # -- 4) Build DataLoaders
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    dataloader_test_random = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=2,
        shuffle=True,
        num_workers=0
    )

    # -- 5) Define optimizer
    optimizer = args.optimizer(params=probunet_model.parameters(), lr=args.lr)
    # Example alternative:
    # optimizer = torch.optim.Adam(probunet_model.parameters(), lr=args.lr, weight_decay=1e-4)

    # -- 6) We track CRPS, KL each epoch for train/val
    train_crps_list, train_kl_list = [], []
    val_crps_list,   val_kl_list = [], []

    # For convenience, we keep your adaptive betas:
    beta_0 = 1.0
    beta_1 = 0.0
    warmup_epochs = 20
    max_beta_1 = 0.001

    print(f"Probabilistic Unet Latent dim: {probunet_model.latent_dim}")

    # -- 7) Main training loop
    for epoch in range(1, args.num_epochs + 1):
        # Set model betas
        probunet_model.beta_0 = beta_0
        probunet_model.beta_1 = beta_1

        print(f"Epoch {epoch}/{args.num_epochs} - beta_0: {probunet_model.beta_0:.4f}, "
              f"beta_1: {probunet_model.beta_1:.4f}")

        # 7a) Train for one epoch (returns mean_crps, mean_kl)
        train_crps, train_kl = tm.train_probunet_step(
            model=probunet_model,
            dataloader=dataloader_train,
            optimizer=optimizer,
            epoch=epoch,
            num_epochs=args.num_epochs,
            device=args.device,        
            ensemble_size=15    # how many samples per forward pass
        )
        train_crps_list.append(train_crps)
        train_kl_list.append(train_kl)

        # 7b) Evaluate on validation data
        val_crps, val_kl = tm.eval_probunet_model(
            model=probunet_model,
            dataloader=dataloader_val,
            device=args.device,
            ensemble_size=5
        )
        val_crps_list.append(val_crps)
        val_kl_list.append(val_kl)

        print(f"[Train] CRPS={train_crps:.4f}, KL={train_kl:.4f}| "
              f"[Val] CRPS={val_crps:.4f}, KL={val_kl:.4f}")
        
        # 7c) Update betas with gradual annealing
        if epoch <= warmup_epochs:
            # Warmup phase: no KL penalty
            beta_0, beta_1 = 1.0, 0.0
        else:
            # Annealing phase: gradually increase beta_1 from 0 to max_beta_1
            # Progress goes from 0 (just after warmup) to 1 (at final epoch)
            total_annealing_epochs = args.num_epochs - warmup_epochs
            current_annealing_epoch = epoch - warmup_epochs
            progress = min(current_annealing_epoch / total_annealing_epochs, 1.0)
            
            beta_0 = 1.0
            beta_1 = progress * max_beta_1
            
            print(f"  → Annealing progress: {progress:.2%} | beta_1 = {beta_1:.4f}")

        # 7d) Example sampling from the model for sanity checks
        test_batch = next(iter(dataloader_test_random))

        # Residual predictions
        residual_preds, (fig, axs) = tm.sample_residual_probunet_model(
            model=probunet_model,
            dataloader=dataloader_test_random,
            epoch=epoch,
            device=args.device,
            batch=test_batch
        )
        fig.savefig(f"{args.plotdir}/epoch{epoch}_residuals.png", dpi=300)
        plt.close(fig)

        fig_difs, axs_difs = dataset_test.plot_residual_differences(
            residual_preds=residual_preds,
            timestamps_float=test_batch['timestamps_float'][:2],
            epoch=epoch,
            N=2, 
            num_samples=3
        )
        fig_difs.savefig(f"{args.plotdir}/epoch{epoch}_res_difs.png", dpi=300)
        plt.close(fig_difs)

        # Full reconstruction
        samples, (fig, axs) = tm.sample_probunet_model(
            model=probunet_model,
            dataloader=dataloader_test_random,
            epoch=epoch,
            device=args.device,
            batch=test_batch
        )
        fig.savefig(f"{args.plotdir}/epoch{epoch}_reconstructed.png", dpi=300)
        plt.close(fig)

    # -- 8) Save final model weights
    torch.save(probunet_model.state_dict(),
               f"{args.plotdir}/probunet_model_lat_dim_{probunet_model.latent_dim}.pth")

    # -- 9) Save losses for analysis
    losses_to_save = {
        "train_crps": train_crps_list,
        "train_kl":   train_kl_list,
        "val_crps":   val_crps_list,
        "val_kl":     val_kl_list,
    }
    with open(f"{args.plotdir}/losses.pkl", "wb") as f:
        pickle.dump(losses_to_save, f)

    # -- NEW: Analyze residual contribution
    print("\n" + "="*60)
    print("ANALYZING RESIDUAL CONTRIBUTION")
    print("="*60)
    tm.analyze_residual_contribution(
        model=probunet_model,
        dataloader=dataloader_test,
        device=args.device,
        num_samples=20
    )
  
    epochs = np.arange(1, args.num_epochs+1)
    plt.plot(epochs, train_crps_list, label='Train CRPS')
    plt.plot(epochs, val_crps_list,   label='Val CRPS', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('CRPS')
    plt.legend()
    plt.title('Training and Validation CRPS')
    plt.savefig(f"{args.plotdir}/CRPS_curve.png", dpi=300)
    plt.close()


Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Probabilistic Unet Latent dim: 16
Epoch 1/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 1/30:   0%|          | 0/685 [00:00<?, ?it/s]

Computing statistics for standardization


Train :: Epoch: 1/30: 100%|██████████| 685/685 [02:59<00:00,  3.81it/s, Loss: 0.1325]
:: Evaluation :::   1%|          | 1/103 [00:00<00:20,  4.99it/s]

Computing statistics for standardization


:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.07it/s]


[Train] CRPS=0.1476, KL=1790.5485| [Val] CRPS=0.1347, KL=2775.1663
Epoch 2/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 2/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1394]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.59it/s]


[Train] CRPS=0.1256, KL=2582.2215| [Val] CRPS=0.1284, KL=2572.4964
Epoch 3/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 3/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1236]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.57it/s]


[Train] CRPS=0.1217, KL=2454.7891| [Val] CRPS=0.1269, KL=2850.2980
Epoch 4/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 4/30: 100%|██████████| 685/685 [02:58<00:00,  3.85it/s, Loss: 0.1146]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.64it/s]


[Train] CRPS=0.1194, KL=2239.7311| [Val] CRPS=0.1249, KL=2108.2311
Epoch 5/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 5/30: 100%|██████████| 685/685 [02:58<00:00,  3.83it/s, Loss: 0.1268]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.72it/s]


[Train] CRPS=0.1178, KL=1901.8787| [Val] CRPS=0.1237, KL=1883.3967
Epoch 6/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 6/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1188]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.69it/s]


[Train] CRPS=0.1165, KL=1770.3983| [Val] CRPS=0.1223, KL=2709.7864
Epoch 7/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 7/30: 100%|██████████| 685/685 [02:58<00:00,  3.85it/s, Loss: 0.1151]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.60it/s]


[Train] CRPS=0.1154, KL=1899.7130| [Val] CRPS=0.1227, KL=1901.7465
Epoch 8/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 8/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1198]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.56it/s]


[Train] CRPS=0.1143, KL=1678.5079| [Val] CRPS=0.1214, KL=1503.0295
Epoch 9/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 9/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1095]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.59it/s]


[Train] CRPS=0.1134, KL=1683.9382| [Val] CRPS=0.1209, KL=1893.5939
Epoch 10/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 10/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1105]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.52it/s]


[Train] CRPS=0.1125, KL=1597.7627| [Val] CRPS=0.1205, KL=1802.4214
Epoch 11/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 11/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1126]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.59it/s]


[Train] CRPS=0.1116, KL=1526.5787| [Val] CRPS=0.1201, KL=1633.6973
Epoch 12/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 12/30: 100%|██████████| 685/685 [02:59<00:00,  3.82it/s, Loss: 0.1058]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.29it/s]


[Train] CRPS=0.1108, KL=1404.7101| [Val] CRPS=0.1198, KL=1701.8596
Epoch 13/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 13/30: 100%|██████████| 685/685 [02:59<00:00,  3.82it/s, Loss: 0.1103]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.64it/s]


[Train] CRPS=0.1100, KL=1377.8187| [Val] CRPS=0.1204, KL=1638.4421
Epoch 14/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 14/30: 100%|██████████| 685/685 [02:58<00:00,  3.83it/s, Loss: 0.1051]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.56it/s]


[Train] CRPS=0.1093, KL=1466.7424| [Val] CRPS=0.1201, KL=2138.6040
Epoch 15/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 15/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1058]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.43it/s]


[Train] CRPS=0.1085, KL=1172.1625| [Val] CRPS=0.1197, KL=1648.1453
Epoch 16/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 16/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1121]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.55it/s]


[Train] CRPS=0.1078, KL=1102.5997| [Val] CRPS=0.1200, KL=1356.9555
Epoch 17/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 17/30: 100%|██████████| 685/685 [02:58<00:00,  3.84it/s, Loss: 0.1072]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.57it/s]


[Train] CRPS=0.1070, KL=1138.4439| [Val] CRPS=0.1200, KL=1375.0674
Epoch 18/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 18/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1118]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.74it/s]


[Train] CRPS=0.1064, KL=975.4800| [Val] CRPS=0.1194, KL=1487.2089
Epoch 19/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 19/30: 100%|██████████| 685/685 [02:57<00:00,  3.86it/s, Loss: 0.1035]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.78it/s]


[Train] CRPS=0.1058, KL=1045.3667| [Val] CRPS=0.1204, KL=1332.9770
Epoch 20/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 20/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1091]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.70it/s]


[Train] CRPS=0.1051, KL=838.9338| [Val] CRPS=0.1194, KL=897.7815
Epoch 21/30 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 21/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1109]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.60it/s]


[Train] CRPS=0.1046, KL=841.8784| [Val] CRPS=0.1204, KL=781.3664
  → Annealing progress: 10.00% | beta_1 = 0.0001
Epoch 22/30 - beta_0: 1.0000, beta_1: 0.0001


Train :: Epoch: 22/30: 100%|██████████| 685/685 [02:58<00:00,  3.85it/s, Loss: 0.1009]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.69it/s]


[Train] CRPS=0.1041, KL=1.1062| [Val] CRPS=0.1206, KL=0.0896
  → Annealing progress: 20.00% | beta_1 = 0.0002
Epoch 23/30 - beta_0: 1.0000, beta_1: 0.0002


Train :: Epoch: 23/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.0953]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.49it/s]


[Train] CRPS=0.1035, KL=0.0649| [Val] CRPS=0.1198, KL=0.0506
  → Annealing progress: 30.00% | beta_1 = 0.0003
Epoch 24/30 - beta_0: 1.0000, beta_1: 0.0003


Train :: Epoch: 24/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1043]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.68it/s]


[Train] CRPS=0.1030, KL=0.0451| [Val] CRPS=0.1205, KL=0.0892
  → Annealing progress: 40.00% | beta_1 = 0.0004
Epoch 25/30 - beta_0: 1.0000, beta_1: 0.0004


Train :: Epoch: 25/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.0957]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.70it/s]


[Train] CRPS=0.1024, KL=0.0367| [Val] CRPS=0.1199, KL=0.0302
  → Annealing progress: 50.00% | beta_1 = 0.0005
Epoch 26/30 - beta_0: 1.0000, beta_1: 0.0005


Train :: Epoch: 26/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1061]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.69it/s]


[Train] CRPS=0.1020, KL=0.0266| [Val] CRPS=0.1208, KL=0.0187
  → Annealing progress: 60.00% | beta_1 = 0.0006
Epoch 27/30 - beta_0: 1.0000, beta_1: 0.0006


Train :: Epoch: 27/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1083]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.67it/s]


[Train] CRPS=0.1015, KL=0.0181| [Val] CRPS=0.1203, KL=0.0188
  → Annealing progress: 70.00% | beta_1 = 0.0007
Epoch 28/30 - beta_0: 1.0000, beta_1: 0.0007


Train :: Epoch: 28/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1019]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.57it/s]


[Train] CRPS=0.1011, KL=0.0188| [Val] CRPS=0.1202, KL=0.0118
  → Annealing progress: 80.00% | beta_1 = 0.0008
Epoch 29/30 - beta_0: 1.0000, beta_1: 0.0008


Train :: Epoch: 29/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.0982]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.68it/s]


[Train] CRPS=0.1007, KL=0.0142| [Val] CRPS=0.1209, KL=0.0169
  → Annealing progress: 90.00% | beta_1 = 0.0009
Epoch 30/30 - beta_0: 1.0000, beta_1: 0.0009


Train :: Epoch: 30/30: 100%|██████████| 685/685 [02:57<00:00,  3.85it/s, Loss: 0.1086]
:: Evaluation ::: 100%|██████████| 103/103 [00:06<00:00, 16.63it/s]


[Train] CRPS=0.1002, KL=0.0157| [Val] CRPS=0.1215, KL=0.0267
  → Annealing progress: 100.00% | beta_1 = 0.0010

ANALYZING RESIDUAL CONTRIBUTION
Computing statistics for standardization

[ANALYSIS] Residual Contribution:
  Mean |predicted_residual|: 0.3149
  Mean |true_residual|: 1.3834
  Error (lrinterp only): 1.3834
  Error (lrinterp + model): 0.9807
  Improvement: 0.4027
  Improvement %: 29.11%


In [None]:
# Training model with wmse-ms-ssim
if __name__ == "__main__":

    def set_seed(seed):
        random.seed(seed) 
        np.random.seed(seed)  
        torch.manual_seed(seed) 
        torch.cuda.manual_seed(seed)  
        torch.cuda.manual_seed_all(seed)  
        torch.backends.cudnn.deterministic = True  
        torch.backends.cudnn.benchmark = False  
        os.environ['PYTHONHASHSEED'] = str(seed)

    # -- 1) Set seed for reproducibility
    set_seed(42)  

    # -- 2) Importing all required arguments
    args = tm.get_args()
    args.lowres_scale = 16
    args.batch_size = 32
    args.num_epochs = 10

    # Initialize the Probabilistic UNet model
    probunet_model = ProbabilisticUNet(
        input_channels=len(args.variables),
        num_classes=len(args.variables),
        latent_dim=32,
        num_filters=[32, 64, 128, 256],
        model_channels=32,
        channel_mult=[1, 2, 4, 8],
        beta_0=0.0,
        beta_1=0.0,
        beta_2=0.0  
    ).to(args.device)

    # -- 3) Prepare datasets
    dataset_train = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_train,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )
    dataset_val = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_val,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )
    dataset_test = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_test,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )

    # -- 4) Build DataLoaders
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    dataloader_test_random = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=2,
        shuffle=True,
        num_workers=0
    )

    # -- 5) Define optimizer
    optimizer = args.optimizer(params=probunet_model.parameters(), lr=args.lr)
    # Example alternative:
    # optimizer = torch.optim.Adam(probunet_model.parameters(), lr=args.lr, weight_decay=1e-4)

    # -- 6) We track CRPS, KL each epoch for train/val
    train_crps_list, train_kl_list = [], []
    val_crps_list,   val_kl_list = [], []

    # For convenience, we keep your adaptive betas:
    beta_0 = 1.0
    beta_1 = 0.0
    warmup_epochs = 2

    print(f"Probabilistic Unet Latent dim: {probunet_model.latent_dim}")

    # -- 7) Main training loop
    for epoch in range(1, args.num_epochs + 1):
        # Set model betas
        probunet_model.beta_0 = beta_0
        probunet_model.beta_1 = beta_1

        print(f"Epoch {epoch}/{args.num_epochs} - beta_0: {probunet_model.beta_0:.4f}, "
              f"beta_1: {probunet_model.beta_1:.4f}")

        # 7a) Train for one epoch (returns mean_crps, mean_kl)
        train_crps, train_kl, train_wmse, train_msssim = tm.train_probunet_step(
            model=probunet_model,
            dataloader=dataloader_train,
            optimizer=optimizer,
            epoch=epoch,
            num_epochs=args.num_epochs,
            device=args.device,       
            ensemble_size=1    # how many samples per forward pass
        )
        train_crps_list.append(train_crps)
        train_kl_list.append(train_kl)

        # 7b) Update betas after warmup
        if epoch > warmup_epochs:
            beta_0 = 1.0 / (train_crps + 1e-7)
            # beta_0 = 1.0
            beta_1 = 1.0 / (train_kl   + 1e-7)
        else:
            beta_0, beta_1 = 1.0, 0.0

        # 7c) Evaluate on validation data
        val_crps, val_kl, val_wmse, val_msssim = tm.eval_probunet_model(
            model=probunet_model,
            dataloader=dataloader_val,
            device=args.device,
            ensemble_size=1
        )
        val_crps_list.append(val_crps)
        val_kl_list.append(val_kl)

        print(f"[Train] recon_loss={train_crps:.4f}, KL={train_kl:.4f}, WMSE={train_wmse:.4f}, MSSSIM={train_msssim:.4f} | "
              f"[Val] recon_loss={val_crps:.4f}, KL={val_kl:.4f}, WMSE={val_wmse:.4f}, MSSSIM={val_msssim:.4f}")
        # 7d) Example sampling from the model for sanity checks
        test_batch = next(iter(dataloader_test_random))

        # Residual predictions
        residual_preds, (fig, axs) = tm.sample_residual_probunet_model(
            model=probunet_model,
            dataloader=dataloader_test_random,
            epoch=epoch,
            device=args.device,
            batch=test_batch
        )
        fig.savefig(f"{args.plotdir}/epoch{epoch}_residuals.png", dpi=300)
        plt.close(fig)

        fig_difs, axs_difs = dataset_test.plot_residual_differences(
            residual_preds=residual_preds,
            timestamps_float=test_batch['timestamps_float'][:2],
            epoch=epoch,
            N=2, 
            num_samples=3
        )
        fig_difs.savefig(f"{args.plotdir}/epoch{epoch}_res_difs.png", dpi=300)
        plt.close(fig_difs)

        # Full reconstruction
        samples, (fig, axs) = tm.sample_probunet_model(
            model=probunet_model,
            dataloader=dataloader_test_random,
            epoch=epoch,
            device=args.device,
            batch=test_batch
        )
        fig.savefig(f"{args.plotdir}/epoch{epoch}_reconstructed.png", dpi=300)
        plt.close(fig)

    # -- 8) Save final model weights
    torch.save(probunet_model.state_dict(),
               f"{args.plotdir}/probunet_model_lat_dim_{probunet_model.latent_dim}.pth")

    # -- 9) Save losses for analysis
    losses_to_save = {
        "train_crps": train_crps_list,
        "train_kl":   train_kl_list,
        "val_crps":   val_crps_list,
        "val_kl":     val_kl_list,
    }
    with open(f"{args.plotdir}/losses.pkl", "wb") as f:
        pickle.dump(losses_to_save, f)

    # -- NEW: Analyze residual contribution
    print("\n" + "="*60)
    print("ANALYZING RESIDUAL CONTRIBUTION")
    print("="*60)
    tm.analyze_residual_contribution(
        model=probunet_model,
        dataloader=dataloader_test,
        device=args.device,
        num_samples=20
    )
  
    epochs = np.arange(1, args.num_epochs+1)
    plt.plot(epochs, train_crps_list, label='Train CRPS')
    plt.plot(epochs, val_crps_list,   label='Val CRPS', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('CRPS')
    plt.legend()
    plt.title('Training and Validation CRPS')
    plt.savefig(f"{args.plotdir}/CRPS_curve.png", dpi=300)
    plt.close()

Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Probabilistic Unet Latent dim: 16
Epoch 1/10 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 1/10:   0%|          | 1/343 [00:00<01:06,  5.13it/s]

Computing statistics for standardization


Train :: Epoch: 1/10: 100%|██████████| 343/343 [00:57<00:00,  5.99it/s, Loss: 0.1232]
:: Evaluation :::   2%|▏         | 2/92 [00:00<00:05, 16.90it/s]

Computing statistics for standardization


:: Evaluation ::: 100%|██████████| 92/92 [00:05<00:00, 17.88it/s]


[Train] recon_loss=0.1687, KL=1406.9236, WMSE=0.0016, MSSSIM=0.1687 | [Val] recon_loss=0.1134, KL=1476.5981, WMSE=0.0010, MSSSIM=0.1134
Epoch 2/10 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 2/10: 100%|██████████| 343/343 [00:57<00:00,  6.02it/s, Loss: 0.1168]
:: Evaluation ::: 100%|██████████| 92/92 [00:05<00:00, 18.30it/s]


[Train] recon_loss=0.0938, KL=1420.8607, WMSE=0.0009, MSSSIM=0.0938 | [Val] recon_loss=0.0980, KL=1178.9982, WMSE=0.0009, MSSSIM=0.0980
Epoch 3/10 - beta_0: 1.0000, beta_1: 0.0000


Train :: Epoch: 3/10:  98%|█████████▊| 335/343 [00:55<00:01,  6.02it/s, Loss: 0.0826]


KeyboardInterrupt: 

In [1]:
import torch
import climex_utils as cu
import train_prob_unet_model as tm  
from prob_unet import ProbabilisticUNet
import numpy as np
import os
from tqdm import tqdm

In [2]:

def load_model(model_path, args, latent_dim, num_filters, model_channels, channel_mult):
    """
    Load a trained ProbabilisticUNet model with specified latent_dim, num_filters, model_channels, and channel_mult.
    
    Args:
        model_path: Path to the saved model weights.
        args: Arguments for loading.
        latent_dim: Latent space dimension for the ProbabilisticUNet.
        num_filters: List specifying the number of filters at each U-Net level.
        model_channels: Base number of feature maps in the U-Net.
        channel_mult: Multipliers for feature maps at each resolution level in the U-Net.

    Returns:
        Loaded model in evaluation mode.
    """
    model = ProbabilisticUNet(
        input_channels=len(args.variables),
        num_classes=len(args.variables),
        latent_dim=latent_dim,
        num_filters=num_filters,
        model_channels=model_channels,
        channel_mult=channel_mult,
        beta_0=0.0,
        beta_1=0.0,
        beta_2=0.0  
    ).to(args.device)
    
    model.load_state_dict(torch.load(model_path))
    model.eval()  
    return model

In [3]:
def generate_samples(model, dataloader_test, num_samples=20, save_dir="./test_predictions"):
    """
    Generate predictions for the entire test set and save the samples to .npy files.
    
    Args:
        model: Trained ProbabilisticUNet model.
        dataloader_test: DataLoader for the test set.
        num_samples: Number of samples per test sample.
        save_dir: Directory where .npy files will be saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    all_preds = []  # Store predictions for all test samples

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader_test, desc="Generating predictions")):
            inputs = batch['inputs'].to(args.device)  
            lrinterp = batch['lrinterp'].to(args.device) 
            timestamps = batch['timestamps'].unsqueeze(dim=1).to(args.device)
            hr_targets = batch['hr']  
            
            sample_preds = []  # Store 50 predictions for this batch
            for _ in range(num_samples):
                preds = model(inputs, t=timestamps, training=False)  
                full_recon = dataloader_test.dataset.residual_to_hr(preds.cpu(), lrinterp.cpu())  
                sample_preds.append(full_recon.numpy())  # Convert to NumPy array

            # Stack predictions along a new axis (shape: [batch_size, 50, C, H, W])
            sample_preds = np.stack(sample_preds, axis=1)
            all_preds.append(sample_preds)

    # Concatenate all batches (shape: [N, 50, C, H, W])
    all_preds = np.concatenate(all_preds, axis=0)
    save_path = os.path.join(save_dir, "predictions.npy")
    np.save(save_path, all_preds)  # Save to .npy file
    print(f"Predictions saved to {save_path}")

In [4]:
args = tm.get_args()

dataset_test = cu.climex2torch(
        datadir=args.datadir,
        years=args.years_test,
        variables=args.variables,
        coords=args.coords,
        lowres_scale=args.lowres_scale,
        type="lrinterp_to_residuals",
        transfo=True
    )

dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )


model_configs = [
    # {
    #     "model_path": "./results/plots/01/05/202521:51:55/probunet_model_lat_dim_32.pth",
    #     "latent_dim": 32,
    #     "num_filters": [32, 64, 128, 256],
    #     "model_channels": 32,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/05/202522:14:38/probunet_model_lat_dim_64.pth",
    #     "latent_dim": 64,
    #     "num_filters": [32, 64, 128, 256],
    #     "model_channels": 32,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/06/202500:40:45/probunet_model_lat_dim_8.pth",
    #     "latent_dim": 8,
    #     "num_filters": [16, 64, 128, 256],
    #     "model_channels": 16,
    #     "channel_mult": [1, 4, 8, 16]
    # },
    # {
    #     "model_path": "./results/plots/01/06/202501:00:11/probunet_model_lat_dim_16.pth",
    #     "latent_dim": 16,
    #     "num_filters": [16, 64, 128, 256],
    #     "model_channels": 16,
    #     "channel_mult": [1, 4, 8, 16]
    # },
    # {
    #     "model_path": "./results/plots/01/07/202518:33:16/probunet_model_lat_dim_32.pth",
    #     "latent_dim": 32,
    #     "num_filters": [32, 64, 128, 256],
    #     "model_channels": 32,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/07/202521:03:10/probunet_model_lat_dim_64.pth",
    #     "latent_dim": 64,
    #     "num_filters": [32, 64, 128, 256],
    #     "model_channels": 32,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/07/202521:29:20/probunet_model_lat_dim_8.pth",
    #     "latent_dim": 8,
    #     "num_filters": [16, 64, 128, 256],
    #     "model_channels": 16,
    #     "channel_mult": [1, 4, 8, 16]
    # },
    # {
    #     "model_path": "./results/plots/01/07/202522:01:12/probunet_model_lat_dim_16.pth",
    #     "latent_dim": 16,
    #     "num_filters": [16, 64, 128, 256],
    #     "model_channels": 16,
    #     "channel_mult": [1, 4, 8, 16]
    # },
    # {
    #     "model_path": "./results/plots/01/07/202522:21:41/probunet_model_lat_dim_64.pth",
    #     "latent_dim": 64,
    #     "num_filters": [64, 128, 256, 512],
    #     "model_channels": 64,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/08/202512:38:04/probunet_model_lat_dim_32.pth",
    #     "latent_dim": 32,
    #     "num_filters": [64, 128, 256, 512],
    #     "model_channels": 64,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    # {
    #     "model_path": "./results/plots/01/08/202513:41:58/probunet_model_lat_dim_64.pth",
    #     "latent_dim": 64,
    #     "num_filters": [64, 128, 256, 512],
    #     "model_channels": 64,
    #     "channel_mult": [1, 2, 4, 8]
    # },
    {
        "model_path": "./results/plots/01/16/202511:40:11/probunet_model_lat_dim_16.pth",
        "latent_dim": 16,
        "num_filters": [16, 64, 128, 256],
        "model_channels": 16,
        "channel_mult": [1, 4, 8, 16]
    },
    
]

for config in model_configs:
    print(f"Testing model from {config['model_path']}")
    
    # Load the model
    probunet_model = load_model(
        model_path=config["model_path"],
        args=args,
        latent_dim=config["latent_dim"],
        num_filters=config["num_filters"],
        model_channels=config["model_channels"],
        channel_mult=config["channel_mult"]
    )

    # Generate predictions and save them
    generate_samples(probunet_model, dataloader_test, num_samples=20, save_dir=f"./test_predictions/{config['latent_dim']}_{config['model_channels']}_Region_64_by_64")


Opening and lazy loading netCDF files
Loading dataset into memory
Converting xarray Dataset to Pytorch tensor

##########################################
############ PROCESSING DONE #############
##########################################

Testing model from ./results/plots/01/16/202511:40:11/probunet_model_lat_dim_16.pth


Generating predictions:   0%|                                                                                                                                      | 0/92 [00:00<?, ?it/s]

Computing statistics for standardization


Generating predictions: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 92/92 [00:24<00:00,  3.83it/s]


Predictions saved to ./test_predictions/16_16_Region_64_by_64/predictions.npy


In [None]:
# Analyze residual contribution after training
import train_prob_unet_model as tm

# After training is complete, analyze how much the model actually helps
tm.analyze_residual_contribution(
    model=probunet_model,
    dataloader=dataloader_test,
    device=args.device,
    num_samples=20  # Test on 20 samples
)