In [None]:
# ! pip install scikit-image
# ! pip install torch torchmetrics zarr torchvision
# ! nvidia-smi

# Import Packages

In [None]:
import torch
from tqdm import tqdm 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import matplotlib.pyplot as plt

import sys
sys.path.append("../../src")
from utils.NOAHminiDataset import *
from utils.NOAHModelUNetFiLM import *

# Helper Functions

In [None]:
def plot_gen_image(image_path, real, generated):
    """Draw a reference image above a generated image, save the figure, and show it.

    Parameters
    ----------
    image_path : str or os.PathLike
        Destination of the saved figure. Missing parent directories are created.
    real : torch.Tensor
        Reference image. It is moved to CPU and converted to numpy; only element
        ``[0]`` of the result (squeezed) is drawn.
    generated : torch.Tensor
        Generated image, drawn the same way as ``real``.

    Notes
    -----
    The figure is closed after it is shown so that repeated calls (e.g. one per
    band inside a loop) do not accumulate open figures.
    """
    from pathlib import Path

    fig, axs = plt.subplots(nrows = 2, ncols = 1, figsize = (10, 5))

    panels = (
        (axs[0], real, 'Real'),
        (axs[1], generated, 'Generated'),
    )
    for ax, image, label in panels:
        ax.imshow(image.cpu().detach().numpy()[0].squeeze())
        # `ax.axis('off')` would also hide the y-label, so remove only the
        # ticks and the frame and keep the panel label visible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_ylabel(label)

    Path(image_path).parent.mkdir(parents = True, exist_ok = True)
    fig.savefig(
        image_path,
        transparent = True,
        bbox_inches = 'tight', # crop to the drawn content
        pad_inches = 0, # no extra margin
    )

    plt.show()
    # release the figure; this function is called once per band
    plt.close(fig)


# Set Up Environment

In [None]:
# Model input dimensions; these must match the architecture the checkpoints were built with.
num_channels = 5 # static rs
conditioning_dim = (24 * 6) + 4 # 24hrs for 6 features + lat, long, elev, cloud cover

# Hyper-parameters. Only `batch_size` is read by the validation cells below;
# `learning_rate` and `num_epochs` are not referenced there.
learning_rate = 0.0001
batch_size = 1
num_epochs = 100

# Seed torch's RNGs so any stochastic step is reproducible across runs.
seed = 42
torch.manual_seed(seed)

# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU and then the CPU so the notebook still runs on smaller machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Dataset class that serves each band key.
band_dataset_classes = {
    'B1': NOAHminiB1Dataset,
    'B2': NOAHminiB2Dataset,
    'B3': NOAHminiB3Dataset,
    'B4': NOAHminiB4Dataset,
    'B5': NOAHminiB5Dataset,
    'B6': NOAHminiB6Dataset,
    'B7': NOAHminiB7Dataset,
    'B8': NOAHminiB8Dataset,
    'B9': NOAHminiB9Dataset,
    'B10': NOAHminiB10Dataset,
    'B11': NOAHminiB11Dataset,
}

# {band: {'train_dataset', 'test_dataset', 'val_dataset'}} for every band.
# The train split uses the class's default `dataset_csv`; test and val pass
# 'test.csv' / 'val.csv' explicitly. Splits are built band by band, in the
# order train -> test -> val.
# NOTE(review): the validation loop only reads 'val_dataset'; building the
# train/test splits here may be unnecessary work.
dataset_dict = {
    band: {
        'train_dataset': dataset_cls(device = device),
        'test_dataset': dataset_cls(device = device, dataset_csv = 'test.csv'),
        'val_dataset': dataset_cls(device = device, dataset_csv = 'val.csv'),
    }
    for band, dataset_cls in band_dataset_classes.items()
}

In [None]:
# Metric objects shared by every band in the validation loop.
# NOTE(review): SSIM uses a fixed data_range of 1.0, while PSNR keeps its default
# (data_range=None). In that case torchmetrics infers the range from the target
# tensor of each call, so the two scores do not use the same range convention.
# Confirm the dataset's value range and consider passing the same data_range to both.
mse = nn.MSELoss(reduction = 'mean')  # 'mean' is MSELoss's default reduction
psnr = PeakSignalNoiseRatio().to(device)
ssim = StructuralSimilarityIndexMeasure(data_range = 1.0).to(device)

# Run Model Validation

In [None]:
import numpy as np  # imported explicitly rather than relying on the star imports


def _weighted_mean_std(values, weights):
    """Return the weighted mean and weighted population std of per-batch ``values``."""
    values = np.asarray(values, dtype = float)
    weights = np.asarray(weights, dtype = float)
    mean = np.average(values, weights = weights)
    std = np.sqrt(np.average((values - mean) ** 2, weights = weights))
    return mean, std


# Raw per-batch metric values (and batch sizes) for each band, kept for inspection.
val_results = {}

for band in dataset_dict:
    # Load this band's checkpoint. map_location keeps the weights on `device`
    # regardless of the device they were saved from.
    model_path = f"../../cache/saved_models/UNet_FiLM_{band}_ssim_all_e10.pth"
    model = UNetFiLM(
        in_channels = num_channels,
        conditioning_dim = conditioning_dim
    ).to(device)
    model.load_state_dict(
        torch.load(model_path, map_location = device, weights_only = True)
    )
    model.eval()

    # Evaluate in a fixed order and keep every sample. The metrics then cover the
    # whole split, and the plotted example is the same on every run.
    val_dataloader = DataLoader(
        dataset_dict[band]['val_dataset'],
        batch_size = batch_size,
        shuffle = False,
        drop_last = False
    )

    # clear cached GPU memory left over from the previous band
    torch.cuda.empty_cache()

    batch_mse, batch_ssim, batch_psnr, batch_sizes = [], [], [], []
    val_batch_progress_bar = tqdm(val_dataloader, desc = "Batch")

    # Gradients are not needed for evaluation; skipping autograd saves GPU memory.
    with torch.no_grad():
        for batch_idx, (gbs, static_rs, targets) in enumerate(val_batch_progress_bar):
            outputs = model(
                x_img = static_rs,
                x_weather = gbs,
            )

            if batch_idx == 0:
                # Compare the ground truth (not the model input) with the prediction.
                plot_gen_image(
                    image_path = f"../../assets/figures/gen_image_unet_flim_{band}.png",
                    real = targets[0],
                    generated = outputs[0]
                )

            # torchmetrics expects (preds, target). The order matters for PSNR,
            # whose inferred data range is taken from the target argument.
            ssim_val = ssim(outputs, targets).item()
            psnr_val = psnr(outputs, targets).item()
            mse_val = mse(outputs, targets).item()

            batch_mse.append(mse_val)
            batch_ssim.append(ssim_val)
            batch_psnr.append(psnr_val)
            batch_sizes.append(gbs.size(0))

            val_batch_progress_bar.set_postfix({
                'mse': mse_val,
                'ssim': ssim_val,
                'psnr': psnr_val
            })

    val_results[band] = {
        'mse': batch_mse,
        'ssim': batch_ssim,
        'psnr': batch_psnr,
        'batch_size': batch_sizes,
    }

    # An empty validation split would otherwise raise ZeroDivisionError.
    if not batch_sizes:
        print(f"{band} Val: no validation batches")
        continue

    # Weight each batch by its size so the mean/std are per-sample statistics.
    mse_mean, mse_std = _weighted_mean_std(batch_mse, batch_sizes)
    ssim_mean, ssim_std = _weighted_mean_std(batch_ssim, batch_sizes)
    psnr_mean, psnr_std = _weighted_mean_std(batch_psnr, batch_sizes)

    print(f"{band} Val MSE:{mse_mean:2f}  SSIM:{ssim_mean:2f}  PSNR:{psnr_mean:2f}")
    print(f"Val (std) MSE:{mse_std:2f}  SSIM:{ssim_std:2f}  PSNR:{psnr_std:2f}")
    
    