In [17]:
import torch
import mae.models_mae
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import yaml
from mae_training import CombinedDataset

In [18]:
# set local rank and define mean and std tensors for normalization purposes
local_rank = 0
mean = torch.tensor([495.7316,  814.1386,  924.5740, 2962.5623, 2640.8833, 1740.3031])[None,:,None,None,None].to(local_rank)
std = torch.tensor([286.9569, 359.3304, 576.3471, 892.2656, 945.9432, 916.1625])[None,:,None,None,None].to(local_rank)

In [19]:
# define the function that prepares the model based on the checkpoint and the correct architecture
def prepare_model(checkpoint, arch='mae_vit_base_patch16'):
        # build model
        model = getattr(mae.models_mae, arch)()
        # load model
        checkpoint_file = torch.load(checkpoint, map_location=f'cuda:{local_rank}')
        msg = model.load_state_dict(checkpoint_file, strict=False)
        print(msg)
        return model

In [20]:
# replace with the job ID of the experiment you want to visualize
job_id = "6231-fair-bs16-2024-01-25_22-04-34"

In [21]:
# load the YAML file
yaml_file_path = Path(f"/workspace/data/lchu/hls/jobs/{job_id}.yaml")
with open(yaml_file_path, "r") as file:
    yaml_data = yaml.safe_load(file)
training_length = yaml_data["training_length"]
save_dir = Path("/workspace/data/lchu/hls/vis/full_6231")
checkpoint = Path(yaml_data["checkpoint_dir"]) / "model_best_ssim.pt"
experiment_name = f"{training_length} Chips"

In [22]:
# # define which checkpoint, name of experiment, and save directory we will use for zero shot
# experiment_name = "Zero-Shot"
# checkpoint = Path("/workspace/gfm-gap-filling/pretraining/epoch-832-loss-0.0473.pt")
# save_dir = Path("/workspace/data/lchu/hls/vis/zero_shot_visualization")

In [23]:
# prepare model
model = prepare_model(checkpoint, 'mae_vit_base_patch16')
print('Model loaded.')

<All keys matched successfully>
Model loaded.


In [24]:
# define validation dataset
val_dataset = CombinedDataset("/workspace/gfm-gap-filling/pretraining/training_data", split="validate", num_frames=3, img_size=224, bands=6, cloud_range=[0.01,1.0],
                              # random_cropping=random_cropping, remove_cloud=True, 
                               normalize=True)

In [25]:
# ensure the length and number of masks are correct
print(f"--> Validation set len = {len(val_dataset)}")
print(f"--> Validation set masks = {val_dataset.n_cloudpaths}")

--> Validation set len = 1621
--> Validation set masks = 1600


In [26]:
# send model to device at local rank
torch.cuda.set_device(local_rank)
model = model.to(torch.cuda.current_device())

In [27]:
# get the low coverage image from the val dataset and ensure the shape is correct
first_batch = torch.from_numpy(val_dataset[80][np.newaxis, ...]).to(local_rank)
print("Shape of the first item:", first_batch.shape)

Shape of the first item: torch.Size([1, 2, 6, 3, 224, 224])


In [28]:
# run model on low coverage image
label_mask_batch = first_batch[:,1,:,:,:,:].to(local_rank)
batch = first_batch[:,0,:,:,:,:].to(local_rank)
loss, pred, mask = model(batch, label_mask_batch, 0.75)

In [29]:
# we un-normalize and re-normalize to reflectance values with scaling factor normalization
# the scaling factor for hls data is 0.0001
# we use torch.ceil() to avoid floating point errors resulting in negative values
input = torch.ceil((batch.detach() * std) + mean) * 0.0001
input_mask = label_mask_batch.detach()
predicted = torch.ceil(model.unpatchify(pred).detach() * std + mean) * 0.0001

# use masks to create input, predicted, and non-cloud tensors where values we don't need are set as 0
input_masked = input * input_mask
predicted_masked = predicted * input_mask
non_cloud = input * (1-input_mask)

# send tensors to numpy
input_masked = input_masked.cpu().numpy()
predicted_masked = predicted_masked.cpu().numpy()
non_cloud = non_cloud.cpu().numpy()

# set values that are masked out to nan so the are not counted in visualizations
input_masked[input_masked == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
non_cloud[non_cloud == 0] = np.nan

In [30]:
### creating a pairgrid for the low coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
# the tensors are in format (Batch, Channel, Time Step, Height, Width)
# therefore, we select the first batch of 1, each channel sequentially, the second time step, and all pixels within H and W
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,0,:,:,:].flatten(),
    'B3': non_cloud[0,1,:,:,:].flatten(),
    'B4': non_cloud[0,2,:,:,:].flatten(),
    'B5': non_cloud[0,3,:,:,:].flatten(),
    'B7': non_cloud[0,4,:,:,:].flatten(),
    'B8': non_cloud[0,5,:,:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,0,:,:,:].flatten(),
    'B3': predicted_masked[0,1,:,:,:].flatten(),
    'B4': predicted_masked[0,2,:,:,:].flatten(),
    'B5': predicted_masked[0,3,:,:,:].flatten(),
    'B7': predicted_masked[0,4,:,:,:].flatten(),
    'B8': predicted_masked[0,5,:,:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,0,:,:,:].flatten(),
    'B3': input_masked[0,1,:,:,:].flatten(),
    'B4': input_masked[0,2,:,:,:].flatten(),
    'B5': input_masked[0,3,:,:,:].flatten(),
    'B7': input_masked[0,4,:,:,:].flatten(),
    'B8': input_masked[0,5,:,:,:].flatten()
})

# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, low coverage
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nViT, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_low_coverage.png', format='png')
plt.close()

# second pairgrid: generated data, low coverage
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nViT, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_low_coverage.png', format='png')
plt.close()

# third pairgrid: non-masked data, low coverage
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nViT, {experiment_name}, Low Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_low_coverage.png', format='png')
plt.close()

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fcc3f2b6140>>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [31]:
# get the high coverage image from the val dataset and ensure the shape is correct
second_batch = torch.from_numpy(val_dataset[1280][np.newaxis, ...]).to(local_rank)
print("Shape of the second item:", second_batch.shape)

Shape of the second item: torch.Size([1, 2, 6, 3, 224, 224])


In [13]:
# run model on high coverage image
label_mask_batch = second_batch[:,1,:,:,:,:].to(local_rank)
batch = second_batch[:,0,:,:,:,:].to(local_rank)
loss, pred, mask = model(batch, label_mask_batch, 0.75)

In [14]:
# we un-normalize and re-normalize to reflectance values with scaling factor normalization
input = torch.ceil((batch.detach() * std) + mean) * 0.0001
input_mask = label_mask_batch.detach()
predicted = torch.ceil(model.unpatchify(pred).detach() * std + mean) * 0.0001

# use masks to create input, predicted, and non-cloud tensors where values we don't need are set as 0
input_masked = input * input_mask
predicted_masked = predicted * input_mask
non_cloud = input * (1-input_mask)

# send tensors to numpy
input_masked = input_masked.cpu().numpy()
predicted_masked = predicted_masked.cpu().numpy()
non_cloud = non_cloud.cpu().numpy()

# set values that are masked out to nan so the are not counted in visualizations
input_masked[input_masked == 0] = np.nan
predicted_masked[predicted_masked == 0] = np.nan
non_cloud[non_cloud == 0] = np.nan

In [15]:
### creating a pairgrid for the low coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
# the tensors are in format (Batch, Channel, Time Step, Height, Width)
# therefore, we select the first batch of 1, each channel sequentially, the second time step, and all pixels within H and W
non_cloud_data = pd.DataFrame({
    'B2': non_cloud[0,0,:,:,:].flatten(),
    'B3': non_cloud[0,1,:,:,:].flatten(),
    'B4': non_cloud[0,2,:,:,:].flatten(),
    'B5': non_cloud[0,3,:,:,:].flatten(),
    'B7': non_cloud[0,4,:,:,:].flatten(),
    'B8': non_cloud[0,5,:,:,:].flatten()
})
gen_data = pd.DataFrame({
    'B2': predicted_masked[0,0,:,:,:].flatten(),
    'B3': predicted_masked[0,1,:,:,:].flatten(),
    'B4': predicted_masked[0,2,:,:,:].flatten(),
    'B5': predicted_masked[0,3,:,:,:].flatten(),
    'B7': predicted_masked[0,4,:,:,:].flatten(),
    'B8': predicted_masked[0,5,:,:,:].flatten()
})
true_data = pd.DataFrame({
    'B2': input_masked[0,0,:,:,:].flatten(),
    'B3': input_masked[0,1,:,:,:].flatten(),
    'B4': input_masked[0,2,:,:,:].flatten(),
    'B5': input_masked[0,3,:,:,:].flatten(),
    'B7': input_masked[0,4,:,:,:].flatten(),
    'B8': input_masked[0,5,:,:,:].flatten()
})

# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, high coverage
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nViT, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_full_coverage.png', format='png')
plt.close()

# second pairgrid: generated data, high coverage
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nViT, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_full_coverage.png', format='png')
plt.close()

# third pairgrid: non-masked data, low coverage
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nViT, {experiment_name}, Full Coverage', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_full_coverage.png', format='png')
plt.close()

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [32]:
# setting up the dataset sampler the same as during training, with a batch size of 1
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
test_kwargs = {"batch_size": 1, "sampler": val_sampler}
common_kwargs = {
        "pin_memory": False,
        "drop_last": True
    }
test_kwargs.update(common_kwargs)
test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
print(len(test_loader))

1621


In [36]:
# adding pixel values to tensors iteratively

# initialize empty tensors to which we will concatenate pixel values for each band
# as we have 6 bands, we initialize empty tensors of size (6, 0)
# the tensors will be (Channel, B*H*W)
true_pixels = torch.empty((6, 0)).to(local_rank)
gen_pixels = torch.empty((6, 0)).to(local_rank)
non_cloud_pixels = torch.empty((6, 0)).to(local_rank)

model.eval()

for idx, batch in tqdm(enumerate(test_loader), initial = 1, total=64):
    # get mask batches from dataset
    label_mask_batch = batch[:,1,:,:,:,:].to(local_rank)

    # get input image batches from dataset
    batch = batch[:,0,:,:,:,:].to(local_rank)
    
    # run model
    loss, pred, mask = model(batch, label_mask_batch, 0.75)

    # once again, we un-normalize and re-normalize to reflectance values with scaling factor normalization    
    input = torch.ceil((batch.detach() * std) + mean) * 0.0001
    input_mask = label_mask_batch.detach()
    predicted = torch.ceil(model.unpatchify(pred).detach() * std + mean) * 0.0001
    
    # use masks to create input, predicted, and non-cloud tensors
    input_masked = input * input_mask
    predicted_masked = predicted * input_mask
    non_cloud = input * (1-input_mask)
    print(non_cloud.shape)
    # use view() to create tensors of shape (Channels, B*H*W) which represents (Channel, Pixel values for entire image)
    non_cloud_pixels_data = non_cloud[0,:,:,:,:].view(6,-1)
    gen_pixels_data = predicted_masked[0,:,:,:,:].view(6,-1)
    true_pixels_data = input_masked[0,:,:,:,:].view(6,-1)

    # concatenate the pixel values of this batch with the running tensors of pixel values
    true_pixels = torch.cat((true_pixels, true_pixels_data), dim=1)
    gen_pixels = torch.cat((gen_pixels, gen_pixels_data), dim=1)
    non_cloud_pixels = torch.cat((non_cloud_pixels, non_cloud_pixels_data), dim=1) 
    
    # at the last iteration, send tensors to numpy and set 0 values to np.nan
    if idx + 1 == 64:
        true_pixels = true_pixels.cpu().numpy()
        gen_pixels = gen_pixels.cpu().numpy()
        non_cloud_pixels = non_cloud_pixels.cpu().numpy()
        true_pixels[true_pixels == 0] = np.nan
        gen_pixels[gen_pixels == 0] = np.nan
        non_cloud_pixels[non_cloud_pixels == 0] = np.nan
        break

  8%|██████▌                                                                             | 5/64 [00:00<00:03, 17.33it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 14%|███████████▊                                                                        | 9/64 [00:00<00:03, 16.19it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 20%|████████████████▊                                                                  | 13/64 [00:00<00:03, 16.18it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 27%|██████████████████████                                                             | 17/64 [00:00<00:02, 15.99it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 33%|███████████████████████████▏                                                       | 21/64 [00:01<00:02, 15.76it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 39%|████████████████████████████████▍                                                  | 25/64 [00:01<00:02, 16.13it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 45%|█████████████████████████████████████▌                                             | 29/64 [00:01<00:02, 15.82it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 52%|██████████████████████████████████████████▊                                        | 33/64 [00:01<00:01, 16.01it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 58%|███████████████████████████████████████████████▉                                   | 37/64 [00:02<00:01, 15.64it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 64%|█████████████████████████████████████████████████████▏                             | 41/64 [00:02<00:01, 15.78it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 70%|██████████████████████████████████████████████████████████▎                        | 45/64 [00:02<00:01, 16.01it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 77%|███████████████████████████████████████████████████████████████▌                   | 49/64 [00:03<00:00, 15.88it/s]

torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])
torch.Size([1, 6, 3, 224, 224])


 77%|███████████████████████████████████████████████████████████████▌                   | 49/64 [00:03<00:00, 15.83it/s]


KeyboardInterrupt: 

In [29]:
# putting these values in a dataframe

non_cloud_data = pd.DataFrame({
    'B2': non_cloud_pixels[0],
    'B3': non_cloud_pixels[1],
    'B4': non_cloud_pixels[2],
    'B5': non_cloud_pixels[3],
    'B7': non_cloud_pixels[4],
    'B8': non_cloud_pixels[5]
})
gen_data = pd.DataFrame({
    'B2': gen_pixels[0],
    'B3': gen_pixels[1],
    'B4': gen_pixels[2],
    'B5': gen_pixels[3],
    'B7': gen_pixels[4],
    'B8': gen_pixels[5]
})
true_data = pd.DataFrame({
    'B2': true_pixels[0],
    'B3': true_pixels[1],
    'B4': true_pixels[2],
    'B5': true_pixels[3],
    'B7': true_pixels[4],
    'B8': true_pixels[5]
})

In [30]:
# define 40 regular bin edges from 0 to 1 in increments of 0.025
bin_edges = [round(i * 0.025, 3) for i in range(40)]

# first pairgrid: true data, all pixels in 192 images
true_data_pairgrid = sns.PairGrid(true_data, diag_sharey=False)
true_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='red')
true_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
true_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Ground Truth Pixels\nViT, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_ground_truth_all.png', format='png')
plt.close()

# second pairgrid: generated data, all pixels in 192 images
gen_data_pairgrid = sns.PairGrid(gen_data, diag_sharey=False)
gen_data_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='blue')
gen_data_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
gen_data_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Generated Pixels\nViT, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_generated_all.png', format='png')
plt.close()

# third pairgrid: non cloud data, all pixels in 192 images
non_cloud_pairgrid = sns.PairGrid(non_cloud_data, diag_sharey=False)
non_cloud_pairgrid.map_lower(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.map_diag(sns.histplot, bins=bin_edges, color='green')
non_cloud_pairgrid.set(xlim=(0, 1), ylim=(0, 1))
non_cloud_pairgrid.fig.set_size_inches(10, 10)
plt.suptitle(f'Relationship Between Band Reflectance Values of Non-Cloud Pixels\nViT, {experiment_name}, 64 Test Images', fontsize=16)
plt.tight_layout()
plt.savefig(save_dir / 'band_correlations_non_cloud_all.png', format='png')
plt.close()

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
