In [1]:
import os
import torch
import datasets.gapfill
import options.common
import options.gan
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import yaml
from tqdm import tqdm
from pathlib import Path

In [2]:
print(os.getcwd())

/workspace/gfm-gap-filling-baseline/gap-filling-baseline


In [3]:
local_rank = 1
mean = torch.tensor([495.7316,  814.1386,  924.5740, 2962.5623, 2640.8833, 1740.3031])[None,:,None,None,None].to(local_rank)
std = torch.tensor([286.9569, 359.3304, 576.3471, 892.2656, 945.9432, 916.1625])[None,:,None,None,None].to(local_rank)

In [4]:
device = torch.device("cuda:{}".format(local_rank))
torch.cuda.set_device(device)

In [5]:
checkpoint = Path("subset_400_2023-08-21-22:19:47_uneven_bs16")
checkpoint_dir = Path("/workspace/gfm-gap-filling-baseline/data/results") / checkpoint
g_net_checkpoint = Path("/workspace/gfm-gap-filling-baseline/data/results") / checkpoint / "model_gnet_best.pt"

In [6]:
with open(checkpoint_dir / "config.yml", "r") as cfg_file:
    CONFIG = yaml.safe_load(cfg_file)
training_length = CONFIG['dataset']['training_length']
experiment_name = f'{training_length} Chips'
save_dir = checkpoint_dir

In [7]:
train_transforms, test_transforms = options.common.get_transforms(CONFIG)

val_dataset = options.common.get_dataset(CONFIG, split="validate", transforms=train_transforms)
val_chip_dataframe = pd.DataFrame(val_dataset.tif_catalog)

print(f"Number of validation images: {len(val_dataset)}")
print(f"Number of validation cloud masks: {val_dataset.n_cloudpaths}")

Number of validation images: 1621
Number of validation cloud masks: 1600


In [8]:
g_net = options.gan.get_generator(CONFIG).to(device)

g_net_state_dict = torch.load(g_net_checkpoint)

g_net.load_state_dict(g_net_state_dict)

<All keys matched successfully>

In [9]:
# Setting up the dataset sampler the same as during training
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
    sampler=val_sampler)
print(len(val_dataloader))

102


In [10]:
for i, data in tqdm(enumerate(val_dataloader), initial = 1, total=102):
    if i == 40:  # Batches are zero-indexed
        sample = data
    elif i == 55:
        fullcoveragesample = data
sample = {k: v.to(device) for k, v in sample.items()}
fullcoveragesample = {k: v.to(device) for k, v in fullcoveragesample.items()}
g_input = sample["masked"]
dest_fake = g_net(g_input)
gen_unmasked = dest_fake * sample["cloud"]

103it [00:35,  2.85it/s]                                                  


In [11]:
non_cloud = sample["masked"].detach().cpu().numpy()
predicted_masked = gen_unmasked.detach().cpu().numpy()
input_masked = sample["unmasked"].detach().cpu().numpy()

non_cloud[non_cloud == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
input_masked[input_masked == 0] = np.nan

In [12]:
non_cloud.shape

(16, 18, 224, 224)

In [13]:
### creating a pairgrid for the low coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,6,:,:].flatten(),
    'B3': non_cloud[0,7,:,:].flatten(),
    'B4': non_cloud[0,8,:,:].flatten(),
    'B5': non_cloud[0,9,:,:].flatten(),
    'B7': non_cloud[0,10,:,:].flatten(),
    'B8': non_cloud[0,11,:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,6,:,:].flatten(),
    'B3': predicted_masked[0,7,:,:].flatten(),
    'B4': predicted_masked[0,8,:,:].flatten(),
    'B5': predicted_masked[0,9,:,:].flatten(),
    'B7': predicted_masked[0,10,:,:].flatten(),
    'B8': predicted_masked[0,11,:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,6,:,:].flatten(),
    'B3': input_masked[0,7,:,:].flatten(),
    'B4': input_masked[0,8,:,:].flatten(),
    'B5': input_masked[0,9,:,:].flatten(),
    'B7': input_masked[0,10,:,:].flatten(),
    'B8': input_masked[0,11,:,:].flatten()
})
bin_edges = [round(i * 0.025, 3) for i in range(40)]

true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_low_coverage.png', format='png')
plt.close()

gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_low_coverage.png', format='png')
plt.close()

non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_low_coverage.png', format='png')
plt.close()

In [14]:
g_input = fullcoveragesample["masked"]
dest_fake = g_net(g_input)
gen_unmasked = dest_fake * fullcoveragesample["cloud"]

non_cloud = fullcoveragesample["masked"].detach().cpu().numpy()
predicted_masked = gen_unmasked.detach().cpu().numpy()
input_masked = fullcoveragesample["unmasked"].detach().cpu().numpy()

non_cloud[non_cloud == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
input_masked[input_masked == 0] = np.nan

In [15]:
### creating a pairgrid for the full coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,6,:,:].flatten(),
    'B3': non_cloud[0,7,:,:].flatten(),
    'B4': non_cloud[0,8,:,:].flatten(),
    'B5': non_cloud[0,9,:,:].flatten(),
    'B7': non_cloud[0,10,:,:].flatten(),
    'B8': non_cloud[0,11,:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,6,:,:].flatten(),
    'B3': predicted_masked[0,7,:,:].flatten(),
    'B4': predicted_masked[0,8,:,:].flatten(),
    'B5': predicted_masked[0,9,:,:].flatten(),
    'B7': predicted_masked[0,10,:,:].flatten(),
    'B8': predicted_masked[0,11,:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,6,:,:].flatten(),
    'B3': input_masked[0,7,:,:].flatten(),
    'B4': input_masked[0,8,:,:].flatten(),
    'B5': input_masked[0,9,:,:].flatten(),
    'B7': input_masked[0,10,:,:].flatten(),
    'B8': input_masked[0,11,:,:].flatten()
})
bin_edges = [round(i * 0.025, 3) for i in range(40)]

true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_full_coverage.png', format='png')
plt.close()

gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_full_coverage.png', format='png')
plt.close()

non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_full_coverage.png', format='png')
plt.close()

In [16]:
# Adding pixel values to tensors iteratively
true_pixels = torch.empty((6, 0)).to(local_rank)
gen_pixels = torch.empty((6, 0)).to(local_rank)
non_cloud_pixels = torch.empty((6, 0)).to(local_rank)

g_net.eval()

with torch.no_grad():
    for idx, sample in tqdm(enumerate(val_dataloader), initial = 1, total=12):
        sample = {k: v.to(device) for k, v in sample.items()}
        g_input = sample["masked"]
        dest_fake = g_net(g_input)
        gen_unmasked = dest_fake * sample["cloud"]
    
        non_cloud_pixels_data = sample["masked"][:,6:12,:,:].reshape(6,-1)
        gen_pixels_data = gen_unmasked[:,6:12,:,:].reshape(6,-1)
        true_pixels_data = sample["unmasked"][:,6:12,:,:].reshape(6,-1)

        true_pixels = torch.cat((true_pixels, true_pixels_data), dim=1)
        gen_pixels = torch.cat((gen_pixels, gen_pixels_data), dim=1)
        non_cloud_pixels = torch.cat((non_cloud_pixels, non_cloud_pixels_data), dim=1) 
        
        if idx + 1 == 12:
            true_pixels = true_pixels.cpu().numpy()
            gen_pixels = gen_pixels.cpu().numpy()
            non_cloud_pixels = non_cloud_pixels.cpu().numpy()
            true_pixels[true_pixels == 0] = np.nan
            gen_pixels[gen_pixels == 0] = np.nan
            non_cloud_pixels[non_cloud_pixels == 0] = np.nan
            break

100%|█████████████████████████████████████| 12/12 [00:05<00:00,  1.86it/s]


In [17]:
non_cloud_pixels.shape

(6, 9633792)

In [18]:
# Putting these values in a dataframe

non_cloud_data = pd.DataFrame({
    'B2': non_cloud_pixels[0],
    'B3': non_cloud_pixels[1],
    'B4': non_cloud_pixels[2],
    'B5': non_cloud_pixels[3],
    'B7': non_cloud_pixels[4],
    'B8': non_cloud_pixels[5]
})
gen_data = pd.DataFrame({
    'B2': gen_pixels[0],
    'B3': gen_pixels[1],
    'B4': gen_pixels[2],
    'B5': gen_pixels[3],
    'B7': gen_pixels[4],
    'B8': gen_pixels[5]
})
true_data = pd.DataFrame({
    'B2': true_pixels[0],
    'B3': true_pixels[1],
    'B4': true_pixels[2],
    'B5': true_pixels[3],
    'B7': true_pixels[4],
    'B8': true_pixels[5]
})

In [19]:
bin_edges = [round(i * 0.025, 3) for i in range(40)]

true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nCGAN, {experiment_name}, 192 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_all.png', format='png')
plt.close()

gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nCGAN, {experiment_name}, 192 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_all.png', format='png')
plt.close()

non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nCGAN, {experiment_name}, 192 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_all.png', format='png')
plt.close()