In [1]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
import h5py
import os

from inference import *
from modules import *
from config import *

In [2]:
# Initialize a Model object to load the pretrained model
folderpath_clf = "/home/samet/et-estimation/src/checkpoint/2023-09-26_09-24-10/"
model = Model(folderpath_clf=folderpath_clf, clf_idx=49)

In [3]:
year = 2021
batch_size = 1
split = "test"

# Initialize dataset
dset = model.init_dset(year, split)
# Initialize iterator
iterator = model.init_iterator(dset)

iterator = DataLoader(
    dset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    )

In [4]:
model.clf.eval()  # Set the model to evaluation mode
print()




In [5]:
n_example = 200
ct = 1
# Iterate over the iterator
for (
    s1,
    s2,
    et,
    weather,
    dem,
    sm,
    soilgrids,
    cropland_mask,
) in iterator:
    # Carry tensors to the specified device (cpu or gpu)
    (
        s1,
        s2,
        et,
        weather,
        dem,
        sm,
        soilgrids,
        cropland_mask,
    ) = model._to_device(
        s1,
        s2,
        et,
        weather,
        dem,
        sm,
        soilgrids,
        cropland_mask,
    )

    R = torch.sum(cropland_mask).item() / (128**2)
    if R < 0.7:
        continue

    with torch.cuda.amp.autocast():
        y_ = model._predict_step(
            s1,
            s2,
            weather,
            dem,
            sm,
            soilgrids,
        ).float()
        
        target_folderpath_s1_tensor = "/home/samet/for-inference/data/tensors/s1"
        target_folderpath_era5_tensor = "/home/samet/for-inference/data/tensors/era5"
        target_folderpath_dem_tensor = "/home/samet/for-inference/data/tensors/dem"
        target_folderpath_et_gt_tensor = "/home/samet/for-inference/data/tensors/et-gt"
        target_folderpath_et_pred_tensor = "/home/samet/for-inference/data/tensors/et-pred"
        target_folderpath_cropland_mask_tensor = "/home/samet/for-inference/data/tensors/cropland-mask"

        os.makedirs(target_folderpath_s1_tensor, exist_ok=True)
        os.makedirs(target_folderpath_era5_tensor, exist_ok=True)
        os.makedirs(target_folderpath_dem_tensor, exist_ok=True)
        os.makedirs(target_folderpath_et_gt_tensor, exist_ok=True)
        os.makedirs(target_folderpath_et_pred_tensor, exist_ok=True)
        os.makedirs(target_folderpath_cropland_mask_tensor, exist_ok=True)

        torch.save(s1, f"{target_folderpath_s1_tensor}/s1_{ct:04}.pt")
        torch.save(weather, f"{target_folderpath_era5_tensor}/era5_{ct:04}.pt")
        torch.save(dem, f"{target_folderpath_dem_tensor}/dem_{ct:04}.pt")
        torch.save(et, f"{target_folderpath_et_gt_tensor}/et_gt_{ct:04}.pt")
        torch.save(y_, f"{target_folderpath_et_pred_tensor}/et_pred_{ct:04}.pt")
        torch.save(cropland_mask, f"{target_folderpath_cropland_mask_tensor}/cropland_mask_{ct:04}.pt")  

    # Convert torch.tensor to numpy.array
    s2_img = s2.cpu().numpy().transpose(0, 2, 3, 1).squeeze()[:,:,2::-1]
    y_pred = y_.cpu().numpy().transpose(0, 2, 3, 1).squeeze()
    et_img = et.cpu().numpy().transpose(0, 2, 3, 1).squeeze()
    cropland_mask_img = cropland_mask.cpu().numpy().squeeze()

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    axes[0].imshow(s2_img)
    axes[1].imshow(et_img, cmap='coolwarm', vmin=0, vmax=10)
    axes[2].imshow(y_pred, cmap='coolwarm', vmin=0, vmax=10)
    axes[3].imshow(cropland_mask_img, cmap='RdYlGn', vmin=0, vmax=1, interpolation='none')
    #plt.show()
    plt.close()
    
    target_folderpath_rgb = "/home/samet/for-inference/data/figs/rgb"
    target_folderpath_et_gt = "/home/samet/for-inference/data/figs/et-gt"
    target_folderpath_et_pred = "/home/samet/for-inference/data/figs/et-pred"
    target_folderpath_cropland_mask = "/home/samet/for-inference/data/figs/cropland-mask"

    os.makedirs(target_folderpath_rgb, exist_ok=True)
    os.makedirs(target_folderpath_et_gt, exist_ok=True)
    os.makedirs(target_folderpath_et_pred, exist_ok=True)
    os.makedirs(target_folderpath_cropland_mask, exist_ok=True)

    plt.imshow(s2_img)
    plt.axis('off')
    plt.savefig(f"{target_folderpath_rgb}/rgb_{ct:04}.png", bbox_inches='tight', dpi=300)
    plt.close()

    plt.imshow(et_img, cmap='coolwarm', vmin=0, vmax=8)
    plt.axis('off')
    plt.savefig(f"{target_folderpath_et_gt}/et_gt_{ct:04}.png", bbox_inches='tight', dpi=300)
    plt.close()

    plt.imshow(y_pred, cmap='coolwarm', vmin=0, vmax=8)
    plt.axis('off')
    plt.savefig(f"{target_folderpath_et_pred}/et_pred_{ct:04}.png", bbox_inches='tight', dpi=300)
    plt.close()

    plt.imshow(cropland_mask_img, cmap='RdYlGn', vmin=0, vmax=1, interpolation='none')
    plt.axis('off')
    plt.savefig(f"{target_folderpath_cropland_mask}/cropland_mask_{ct:04}.png", bbox_inches='tight', dpi=300)
    plt.close()


    if ct >= n_example:
        break
    ct += 1

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
