In [32]:
import os
import torch
import datasets.gapfill
import options.common
import options.gan
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import yaml
from tqdm import tqdm
from pathlib import Path

In [33]:
# set local rank and define mean and std tensors for normalization purposes
local_rank = 1
mean = torch.tensor([495.7316,  814.1386,  924.5740, 2962.5623, 2640.8833, 1740.3031])[None,:,None,None,None].to(local_rank)
std = torch.tensor([286.9569, 359.3304, 576.3471, 892.2656, 945.9432, 916.1625])[None,:,None,None,None].to(local_rank)

In [34]:
device = torch.device("cuda:{}".format(local_rank))
torch.cuda.set_device(device)

In [35]:
# replace with the job ID of the experiment you want to visualize
job_id = "subset_6231_2024-01-30-19:45:13"

In [36]:
checkpoint = Path(job_id)
checkpoint_dir = Path("/workspace/gfm-gap-filling-baseline/data/results") / checkpoint
g_net_checkpoint = Path("/workspace/gfm-gap-filling-baseline/data/results") / checkpoint / "model_gnet_best.pt"

In [37]:
# load the YAML file
with open(checkpoint_dir / "config.yml", "r") as cfg_file:
    CONFIG = yaml.safe_load(cfg_file)
training_length = CONFIG['dataset']['training_length']
experiment_name = f'{training_length} Chips'
save_dir = checkpoint_dir

In [38]:
train_transforms, test_transforms = options.common.get_transforms(CONFIG)

# define validation dataset
val_dataset = options.common.get_dataset(CONFIG, split="validate", transforms=train_transforms)
val_chip_dataframe = pd.DataFrame(val_dataset.tif_catalog)

# ensure the length and number of masks are correct
print(f"Number of validation images: {len(val_dataset)}")
print(f"Number of validation cloud masks: {val_dataset.n_cloudpaths}")

Number of validation images: 1621
Number of validation cloud masks: 1600


In [39]:
# prepare model
g_net = options.gan.get_generator(CONFIG).to(device)

g_net_state_dict = torch.load(g_net_checkpoint)

g_net.load_state_dict(g_net_state_dict)

<All keys matched successfully>

In [40]:
# setting up the dataset sampler the same as during training
# it is very important that the batch size is identical to the batch size used during training
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
    sampler=val_sampler)
print(len(val_dataloader))

102


In [41]:
# get low coverage sample and full coverage sample by iterating over dataset
for i, data in tqdm(enumerate(val_dataloader), initial = 1, total=102):
    if i == 5:  # Batches are zero-indexed
        sample = data
    elif i == 80:
        fullcoveragesample = data

# send both samples to the cuda device
sample = {k: v.to(device) for k, v in sample.items()}
fullcoveragesample = {k: v.to(device) for k, v in fullcoveragesample.items()}

# get input, run model, and get ground truth
g_input = sample["masked"]
dest_fake = g_net(g_input)
gen_unmasked = dest_fake * sample["cloud"]

103it [00:26,  3.86it/s]                                                                                                


In [18]:
# send tensors to numpy
non_cloud = sample["masked"].detach().cpu().numpy()
predicted_masked = gen_unmasked.detach().cpu().numpy()
input_masked = sample["unmasked"].detach().cpu().numpy()

# set values that are masked out to nan so the are not counted in visualizations
non_cloud[non_cloud == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
input_masked[input_masked == 0] = np.nan

In [20]:
### creating a pairgrid for the low coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
# the tensors are in format (Batch, Channel, Time Step, Height, Width)
# therefore, we select the first batch of 1, each channel sequentially, the second time step, and all pixels within H and W
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,[0,6,12],:,:].flatten(),
    'B3': non_cloud[0,[1,7,13],:,:].flatten(),
    'B4': non_cloud[0,[2,8,14],:,:].flatten(),
    'B5': non_cloud[0,[3,9,15],:,:].flatten(),
    'B7': non_cloud[0,[4,10,16],:,:].flatten(),
    'B8': non_cloud[0,[5,11,17],:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,[0,6,12],:,:].flatten(),
    'B3': predicted_masked[0,[1,7,13],:,:].flatten(),
    'B4': predicted_masked[0,[2,8,14],:,:].flatten(),
    'B5': predicted_masked[0,[3,9,15],:,:].flatten(),
    'B7': predicted_masked[0,[4,10,16],:,:].flatten(),
    'B8': predicted_masked[0,[5,11,17],:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,[0,6,12],:,:].flatten(),
    'B3': input_masked[0,[1,7,13],:,:].flatten(),
    'B4': input_masked[0,[2,8,14],:,:].flatten(),
    'B5': input_masked[0,[3,9,15],:,:].flatten(),
    'B7': input_masked[0,[4,10,16],:,:].flatten(),
    'B8': input_masked[0,[5,11,17],:,:].flatten()
})

# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, low coverage
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_low_coverage.png', format='png')
plt.close()

# second pairgrid: generated data, low coverage
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_low_coverage.png', format='png')
plt.close()

# third pairgrid: non-masked data, low coverage
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_low_coverage.png', format='png')
plt.close()

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [42]:
# get input, run model, and get ground truth for full coverage sample
g_input = fullcoveragesample["masked"]
dest_fake = g_net(g_input)
gen_unmasked = dest_fake * fullcoveragesample["cloud"]

# send tensors to numpy
non_cloud = fullcoveragesample["masked"].detach().cpu().numpy()
predicted_masked = gen_unmasked.detach().cpu().numpy()
input_masked = fullcoveragesample["unmasked"].detach().cpu().numpy()

# set values that are masked out to nan so the are not counted in visualizations
non_cloud[non_cloud == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
input_masked[input_masked == 0] = np.nan

In [43]:
### creating a pairgrid for the full coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
# the tensors are in format (Batch, Channel, Time Step, Height, Width)
# therefore, we select the first batch of 1, each channel sequentially, the second time step, and all pixels within H and W
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,[0,6,12],:,:].flatten(),
    'B3': non_cloud[0,[1,7,13],:,:].flatten(),
    'B4': non_cloud[0,[2,8,14],:,:].flatten(),
    'B5': non_cloud[0,[3,9,15],:,:].flatten(),
    'B7': non_cloud[0,[4,10,16],:,:].flatten(),
    'B8': non_cloud[0,[5,11,17],:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,[0,6,12],:,:].flatten(),
    'B3': predicted_masked[0,[1,7,13],:,:].flatten(),
    'B4': predicted_masked[0,[2,8,14],:,:].flatten(),
    'B5': predicted_masked[0,[3,9,15],:,:].flatten(),
    'B7': predicted_masked[0,[4,10,16],:,:].flatten(),
    'B8': predicted_masked[0,[5,11,17],:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,[0,6,12],:,:].flatten(),
    'B3': input_masked[0,[1,7,13],:,:].flatten(),
    'B4': input_masked[0,[2,8,14],:,:].flatten(),
    'B5': input_masked[0,[3,9,15],:,:].flatten(),
    'B7': input_masked[0,[4,10,16],:,:].flatten(),
    'B8': input_masked[0,[5,11,17],:,:].flatten()
})

# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, high coverage
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_full_coverage.png', format='png')
plt.close()

# second pairgrid: generated data, high coverage
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_full_coverage.png', format='png')
plt.close()

# third pairgrid: non-masked data, low coverage
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_full_coverage.png', format='png')
plt.close()

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [54]:
# Adding pixel values to tensors iteratively

# initialize empty tensors to which we will concatenate pixel values for each band
# as we have 6 bands, we initialize empty tensors of size (6, 0)
# the tensors will be (Channel, B*H*W)
true_pixels = torch.empty((6, 0)).to(local_rank)
gen_pixels = torch.empty((6, 0)).to(local_rank)
non_cloud_pixels = torch.empty((6, 0)).to(local_rank)

g_net.eval()

with torch.no_grad():
    for idx, sample in tqdm(enumerate(val_dataloader), initial = 1, total=4):
        # send sample to device
        sample = {k: v.to(device) for k, v in sample.items()}
        
        # get input, run model, get ground truth
        g_input = sample["masked"]
        dest_fake = g_net(g_input)
        gen_unmasked = dest_fake * sample["cloud"]
        
        # reshape (B, C, H, W) to (6, B*H*W), where we take the 6 bands representing the middle time step
        non_cloud_pixels_data = sample["masked"][:,:6,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        gen_pixels_data = gen_unmasked[:,:6,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        true_pixels_data = sample["unmasked"][:,:6,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)

        non_cloud_pixels_data2 = sample["masked"][:,6:12,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        gen_pixels_data2 = gen_unmasked[:,6:12,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        true_pixels_data2 = sample["unmasked"][:,6:12,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)

        non_cloud_pixels_data3 = sample["masked"][:,12:,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        gen_pixels_data3 = gen_unmasked[:,12:,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)
        true_pixels_data3 = sample["unmasked"][:,12:,:,:].permute(1, 0, 2, 3).contiguous().view(6,-1)

        # concatenate all pixel values from middle time step for each band to running tensors
        true_pixels = torch.cat((true_pixels, true_pixels_data, true_pixels_data2, true_pixels_data3), dim=1)
        gen_pixels = torch.cat((gen_pixels, gen_pixels_data, gen_pixels_data2, gen_pixels_data3), dim=1)
        non_cloud_pixels = torch.cat((non_cloud_pixels, non_cloud_pixels_data, non_cloud_pixels_data2, non_cloud_pixels_data3), dim=1) 
        
        # at the last iteration, send tensors to numpy and set 0 values to np.nan
        if idx + 1 == 4:
            true_pixels = true_pixels.cpu().numpy()
            gen_pixels = gen_pixels.cpu().numpy()
            non_cloud_pixels = non_cloud_pixels.cpu().numpy()
            true_pixels[true_pixels == 0] = np.nan
            gen_pixels[gen_pixels == 0] = np.nan
            non_cloud_pixels[non_cloud_pixels == 0] = np.nan
            break

100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  1.53it/s]


In [55]:
# putting these values in a dataframe

non_cloud_data = pd.DataFrame({
    'B2': non_cloud_pixels[0],
    'B3': non_cloud_pixels[1],
    'B4': non_cloud_pixels[2],
    'B5': non_cloud_pixels[3],
    'B7': non_cloud_pixels[4],
    'B8': non_cloud_pixels[5]
})
gen_data = pd.DataFrame({
    'B2': gen_pixels[0],
    'B3': gen_pixels[1],
    'B4': gen_pixels[2],
    'B5': gen_pixels[3],
    'B7': gen_pixels[4],
    'B8': gen_pixels[5]
})
true_data = pd.DataFrame({
    'B2': true_pixels[0],
    'B3': true_pixels[1],
    'B4': true_pixels[2],
    'B5': true_pixels[3],
    'B7': true_pixels[4],
    'B8': true_pixels[5]
})

In [56]:
# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, all pixels in 192 images
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_all.png', format='png')
plt.close()

# second pairgrid: generated data, all pixels in 192 images
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_all.png', format='png')
plt.close()

# third pairgrid: non cloud data, all pixels in 192 images
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_all.png', format='png')
plt.close()

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
