In [1]:
# ! pip install scikit-image
# ! pip install torch torchmetrics zarr torchvision
# ! nvidia-smi

# Import Packages

In [None]:
import torch
from tqdm import tqdm 
import torch.nn as nn
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import matplotlib.pyplot as plt

import sys
sys.path.append("../../src")
from utils.NOAHminiDataset import *
from utils.NOAHModelUNetFiLM import *

# Helper Functions

In [None]:
def plot_gen_image(image_path, real, generated):
    # build plot
    fig, axs = plt.subplots(
        nrows = 2, 
        ncols = 1, 
        figsize=(10, 5)
    )
    
    # plot real
    axs[0].imshow(
        real.cpu().detach().numpy()[0].squeeze(),
        # vmin = -1, 
        # vmax = 1
    )
    axs[0].axis('off')
    # axs[0].set_title(f'Band: {band}')
    axs[0].set_ylabel(f'Real')
    
    
    # plot generated
    axs[1].imshow(
        generated.cpu().detach().numpy()[0].squeeze(),
        # vmin = -1, 
        # vmax = 1
    )
    axs[1].axis('off')
    # axs[1].set_title(f'Band: {band} (Generated)')
    axs[1].set_ylabel(f'Generated')

    plt.savefig(
        image_path, 
        transparent = True,
        bbox_inches = 'tight', # compress the content  
        pad_inches = 0, # have no extra margin
    )
    
    plt.show()


# Set-up Environment

In [None]:
num_channels = 5 # static rs
conditioning_dim = (24 * 6) + 4 # 24hrs for 6 features + lat, long, elev, cloud cover

learning_rate = 0.00001
batch_size = 1
num_epochs = 100

# device = 'cpu'
# device = 'cuda:0'
device = 'cuda:1'

In [None]:
dataset_dict = {
    'B1': {
        'train_dataset': NOAHminiB1Dataset(device = device),
        'test_dataset': NOAHminiB1Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB1Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B2': {
        'train_dataset': NOAHminiB2Dataset(device = device),
        'test_dataset': NOAHminiB2Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB2Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B3': {
        'train_dataset': NOAHminiB3Dataset(device = device),
        'test_dataset': NOAHminiB3Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB3Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B4': {
        'train_dataset': NOAHminiB4Dataset(device = device),
        'test_dataset': NOAHminiB4Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB4Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B5': {
        'train_dataset': NOAHminiB5Dataset(device = device),
        'test_dataset': NOAHminiB5Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB5Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B6': {
        'train_dataset': NOAHminiB6Dataset(device = device),
        'test_dataset': NOAHminiB6Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB6Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B7': {
        'train_dataset': NOAHminiB7Dataset(device = device),
        'test_dataset': NOAHminiB7Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB7Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B8': {
        'train_dataset': NOAHminiB8Dataset(device = device),
        'test_dataset': NOAHminiB8Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB8Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B9': {
        'train_dataset': NOAHminiB9Dataset(device = device),
        'test_dataset': NOAHminiB9Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB9Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B10': {
        'train_dataset': NOAHminiB10Dataset(device = device),
        'test_dataset': NOAHminiB10Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB10Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
    'B11': {
        'train_dataset': NOAHminiB11Dataset(device = device),
        'test_dataset': NOAHminiB11Dataset(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': NOAHminiB11Dataset(
            device = device,
            dataset_csv = 'val.csv'
        ),
    },
}

In [11]:
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr = PeakSignalNoiseRatio().to(device)
mse = nn.MSELoss()

# Run Model validation

In [None]:
for band in dataset_dict:

    for channel in range(num_channels):
        channels = [c for c in range(num_channels) if c != channel]
        print(f"Working on {band} -> Channels: {channels}....")

        # get path
        model_path =  f"../../cache/saved_models/UNet_FiLM_{band}_ssim_{channels}_run1.pth"

        # init model
        model = UNetFiLM(
            in_channels = num_channels -1, # because we are removing 1 for ablation
            conditioning_dim = conditioning_dim
        ).to(device)
        # load save state dict
        model.load_state_dict(torch.load(model_path, weights_only=True))

        # get val dataset
        val_dataset = dataset_dict[band]['val_dataset']
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size = batch_size,
            shuffle = True, 
            drop_last = True
        )

        # predict
        model.eval()
        
        # clear cache
        torch.cuda.empty_cache()

        # get validation results
        running_mse = []
        running_ssim = []
        running_psnr = []
        data_count = 0

        # loop over validation data
        val_batch_progress_bar = tqdm(
            val_dataloader,
            desc = "Batch"
        )

        is_image = True
        
        for gbs, static_rs, targets in val_batch_progress_bar:
            outputs = model(
                x_img = static_rs[:,channels], 
                x_weather = gbs,
            )

            if is_image:
                is_image = False
                # plot image
                plot_gen_image(
                    image_path = f"../../assets/figures/gen_image_unet_flim_{band}.png",
                    real = static_rs[0],
                    generated = outputs[0]
                )
                
            
            # get ssim
            ssim_val = ssim(targets, outputs).item()
            # get psnr
            psnr_val = psnr(targets, outputs).item()
            # mse
            mse_val = mse(targets, outputs).item()
            
            data_count += gbs.size(0)
            running_mse.append(mse_val * gbs.size(0))
            running_ssim.append(ssim_val * gbs.size(0))
            running_psnr.append(psnr_val * gbs.size(0))
            
            val_batch_progress_bar.set_postfix({
                'mse': mse_val,
                'ssim': ssim_val,
                'psnr': psnr_val
            })
        print(f"{band} Val MSE:{(sum(running_mse)/data_count):2f}  SSIM:{(sum(running_ssim)/data_count):2f}  PSNR:{(sum(running_psnr)/data_count):2f}")
        print(f"Val (std) MSE:{(np.array(running_mse).std()):2f}  SSIM:{(np.array(running_ssim).std()):2f}  PSNR:{(np.array(running_psnr).std()):2f}")
        print(f"Val MSE:{running_mse}  SSIM:{running_ssim}  PSNR:{running_psnr}")
        
    

Working on B8 -> Channels: [1, 2, 3, 4]....


Batch:   0%|          | 0/15 [00:00<?, ?it/s]

Batch: 100%|██████████| 15/15 [00:04<00:00,  3.06it/s, mse=0.958, ssim=-0.687, psnr=-7.46] 


B8 Val MSE:1.703165  SSIM:0.102453  PSNR:-0.306109
Working on B8 -> Channels: [0, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.37it/s, mse=0.191, ssim=0.328, psnr=6.4]    


B8 Val MSE:0.991394  SSIM:0.104520  PSNR:-0.857896
Working on B8 -> Channels: [0, 1, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.36it/s, mse=1.23, ssim=0.14, psnr=1.22]    


B8 Val MSE:1.104813  SSIM:0.141327  PSNR:-1.751382
Working on B8 -> Channels: [0, 1, 2, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.28it/s, mse=0.38, ssim=0.335, psnr=4.76]   


B8 Val MSE:1.282298  SSIM:0.076486  PSNR:-0.989720
Working on B8 -> Channels: [0, 1, 2, 3]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.29it/s, mse=0.742, ssim=0.169, psnr=1.64]  


B8 Val MSE:1.198546  SSIM:0.057141  PSNR:-1.431296
Working on B9 -> Channels: [1, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.32it/s, mse=1.96, ssim=-0.00325, psnr=-1.91] 


B9 Val MSE:1.715343  SSIM:0.195507  PSNR:-0.964845
Working on B9 -> Channels: [0, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.31it/s, mse=4.26, ssim=0.000318, psnr=-6.34]


B9 Val MSE:0.928416  SSIM:0.390177  PSNR:-0.533872
Working on B9 -> Channels: [0, 1, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.32it/s, mse=2.86, ssim=0.406, psnr=-17.2]  


B9 Val MSE:1.021915  SSIM:0.371390  PSNR:0.448068
Working on B9 -> Channels: [0, 1, 2, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.27it/s, mse=0.394, ssim=0.018, psnr=-19.5]  


B9 Val MSE:1.124819  SSIM:0.488403  PSNR:-2.029593
Working on B9 -> Channels: [0, 1, 2, 3]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.34it/s, mse=0.502, ssim=0.808, psnr=-3.95]  


B9 Val MSE:0.780028  SSIM:0.577123  PSNR:0.648480
Working on B10 -> Channels: [1, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.33it/s, mse=0.679, ssim=0.256, psnr=4.04]  


B10 Val MSE:0.876202  SSIM:0.341166  PSNR:-1.010282
Working on B10 -> Channels: [0, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.35it/s, mse=0.00167, ssim=0.961, psnr=23.6] 


B10 Val MSE:0.766439  SSIM:0.441515  PSNR:1.136684
Working on B10 -> Channels: [0, 1, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.33it/s, mse=3.06, ssim=0.0475, psnr=-15.8] 


B10 Val MSE:0.693469  SSIM:0.452256  PSNR:2.483800
Working on B10 -> Channels: [0, 1, 2, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.24it/s, mse=2.51, ssim=0.0398, psnr=-18.3] 


B10 Val MSE:0.720294  SSIM:0.397797  PSNR:-1.585514
Working on B10 -> Channels: [0, 1, 2, 3]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.30it/s, mse=0.145, ssim=0.495, psnr=12]    


B10 Val MSE:1.565427  SSIM:0.116421  PSNR:-5.010538
Working on B11 -> Channels: [1, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.30it/s, mse=0.107, ssim=0.876, psnr=8.41]   


B11 Val MSE:0.682199  SSIM:0.429693  PSNR:3.119110
Working on B11 -> Channels: [0, 2, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.31it/s, mse=0.233, ssim=0.615, psnr=5.24]   


B11 Val MSE:0.670348  SSIM:0.528468  PSNR:0.213288
Working on B11 -> Channels: [0, 1, 3, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.26it/s, mse=0.0542, ssim=0.633, psnr=13.1]  


B11 Val MSE:0.961516  SSIM:0.415041  PSNR:3.274360
Working on B11 -> Channels: [0, 1, 2, 4]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.33it/s, mse=0.381, ssim=0.452, psnr=6.6]  


B11 Val MSE:0.677977  SSIM:0.374280  PSNR:2.179719
Working on B11 -> Channels: [0, 1, 2, 3]....


Batch: 100%|██████████| 15/15 [00:04<00:00,  3.33it/s, mse=0.163, ssim=0.631, psnr=12.7] 

B11 Val MSE:0.625993  SSIM:0.559442  PSNR:1.856943



