In [1]:
import os
import torch
import mae.models_mae
from mae.models_mae import MaskedAutoencoderViT
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from pathlib import Path
import tqdm
from mae_training import CombinedDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CUDA device index used for every tensor and for the model in this notebook
local_rank = 1

# Per-band statistics used to undo the dataset normalization. Each 6-element
# vector is reshaped to (1, 6, 1, 1, 1) so that it broadcasts along the band
# axis of the 5-D image batches used below.
mean = torch.tensor(
    [495.7316, 814.1386, 924.5740, 2962.5623, 2640.8833, 1740.3031]
).view(1, -1, 1, 1, 1).to(local_rank)
std = torch.tensor(
    [286.9569, 359.3304, 576.3471, 892.2656, 945.9432, 916.1625]
).view(1, -1, 1, 1, 1).to(local_rank)

In [3]:
# define the function that prepares the model based on the checkpoint and the correct architecture
def prepare_model(checkpoint, arch='mae_vit_base_patch16'):
    """Build a model from ``mae.models_mae`` and load checkpoint weights into it.

    Args:
        checkpoint: Path of the checkpoint file handed to ``torch.load``. The
            loaded object is passed directly to ``load_state_dict``, so it must
            be a state dict.
        arch: Name of the factory function looked up on ``mae.models_mae``.

    Returns:
        The constructed model with the checkpoint weights loaded. The model is
        not moved to any device here; the caller does that.
    """
    # build model
    model = getattr(mae.models_mae, arch)()
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's existing parameters, so staging the checkpoint on a GPU first
    # only costs device memory and made this function depend on the global
    # `local_rank`.
    checkpoint_file = torch.load(checkpoint, map_location='cpu')
    # strict=False tolerates key mismatches; the printed result lists any
    # missing or unexpected keys so they do not go unnoticed.
    msg = model.load_state_dict(checkpoint_file, strict=False)
    print(msg)
    return model

In [4]:
# define which checkpoint, name of experiment, and save directory we will use
experiment_name = "Zero-Shot"  # interpolated into the figure titles below
checkpoint = Path("/workspace/gfm-gap-filling/pretraining/epoch-832-loss-0.0473.pt")
save_dir = Path("/workspace/data/lchu/hls/vis/zero_shot_visualization")
# Create the output folder up front: plt.savefig does not create missing
# directories and would otherwise raise FileNotFoundError after the expensive
# model passes have already run.
save_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Build the 'mae_vit_base_patch16' architecture and restore its weights
# from the checkpoint file configured above.
model = prepare_model(checkpoint, arch='mae_vit_base_patch16')
print('Model loaded.')

<All keys matched successfully>
Model loaded.


In [6]:
# Validation split of the combined dataset; one argument per line so the
# configuration is easy to scan.
val_dataset = CombinedDataset(
    "/workspace/gfm-gap-filling/pretraining/train_single_band/train_single_band",
    split="validate",
    num_frames=3,
    img_size=224,
    bands=6,
    cloud_range=[0.01, 1.0],
    # random_cropping and remove_cloud are not passed here
    normalize=True,
)

In [7]:
# ensure the length and number of masks are correct
# Sanity check before running inference: report how many items the validation
# dataset exposes and the value of its `n_cloudpaths` attribute (labelled as
# the number of masks).
print(f"--> Validation set len = {len(val_dataset)}")
print(f"--> Validation set masks = {val_dataset.n_cloudpaths}")

--> Validation set len = 1621
--> Validation set masks = 1600


In [8]:
# Make `local_rank` the current CUDA device, then move the model onto
# whichever device CUDA now reports as current.
torch.cuda.set_device(local_rank)
active_device = torch.cuda.current_device()
model = model.to(active_device)

In [9]:
# Fetch the low coverage example (dataset index 640), prepend a batch axis,
# and move it to the GPU.
low_coverage_sample = val_dataset[640]
first_batch = torch.from_numpy(low_coverage_sample[np.newaxis, ...]).to(local_rank)
print("Shape of the first item:", first_batch.shape)

Shape of the first item: torch.Size([1, 2, 6, 3, 224, 224])


In [10]:
# run model on low coverage image
# Axis 1 of `first_batch` has two entries: index 0 is used as the image and
# index 1 as the mask.
label_mask_batch = first_batch[:,1,:,:,:,:].to(local_rank)
batch = first_batch[:,0,:,:,:,:].to(local_rank)
# Inference only: switch to eval mode (as the pooled loop further down does)
# and skip autograd bookkeeping, which is never used because the outputs are
# detached afterwards anyway.
model.eval()
with torch.no_grad():
    # third argument is presumably the mask ratio -- TODO confirm against the model's forward()
    loss, pred, mask = model(batch, label_mask_batch, 0.75)

In [11]:
# Undo the dataset normalization (x * std + mean), round up to whole numbers,
# then apply the 0.0001 scaling factor to express values as reflectance.
input_mask = label_mask_batch.detach()
input = torch.ceil(batch.detach() * std + mean) * 0.0001
predicted = torch.ceil(model.unpatchify(pred).detach() * std + mean) * 0.0001

# Split pixels by mask value and bring the results to host memory as numpy
# arrays: masked ground truth, masked prediction, and the unmasked remainder.
input_masked = (input * input_mask).cpu().numpy()
predicted_masked = (predicted * input_mask).cpu().numpy()
non_cloud = (input * (1 - input_mask)).cpu().numpy()

# Multiplying by the mask leaves zeros for excluded pixels; mark every exact
# zero as NaN so those pixels are left out of the plots.
for pixel_array in (input_masked, predicted_masked, non_cloud):
    pixel_array[pixel_array == 0] = np.nan

In [12]:
### creating a pairgrid for the low coverage image

# putting the data into a dataframe where each column represents an ordered list of band values
# (batch item 0, frame index 1 of each band, flattened to one value per pixel)
band_names = ['B2', 'B3', 'B4', 'B5', 'B7', 'B8']
non_cloud_data = pd.DataFrame({
    band: non_cloud[0, band_idx, 1, :, :].flatten()
    for band_idx, band in enumerate(band_names)
})
gen_data = pd.DataFrame({
    band: predicted_masked[0, band_idx, 1, :, :].flatten()
    for band_idx, band in enumerate(band_names)
})
true_data = pd.DataFrame({
    band: input_masked[0, band_idx, 1, :, :].flatten()
    for band_idx, band in enumerate(band_names)
})

# 21 edges -> 20 bins of width 0.05 covering the full [0, 1] axis range.
# range(20) stopped at 0.95, so pixels in (0.95, 1.0] were silently dropped
# from every histogram even though the axes extend to 1.
bin_edges = [round(i * 0.05, 2) for i in range(21)]

# One entry per figure: (data, colour, pixel-kind label for the title, output file name)
plot_specs = [
    (true_data, 'red', 'Ground Truth', 'band_correlations_ground_truth_low_coverage.png'),
    (gen_data, 'blue', 'Generated', 'band_correlations_generated_low_coverage.png'),
    (non_cloud_data, 'green', 'Non-Cloud', 'band_correlations_non_cloud_low_coverage.png'),
]

pairgrids = []
for frame, color, pixel_kind, file_name in plot_specs:
    grid = sns.PairGrid(frame, diag_sharey=False)
    grid.map_lower(sns.histplot, bins=bin_edges, color=color)
    grid.map_diag(sns.histplot, bins=bin_edges, color=color)
    grid.set(xlim=(0, 1), ylim=(0, 1))
    # Work on the grid's own figure rather than pyplot's implicit current figure
    fig = grid.fig
    fig.set_size_inches(10, 10)
    fig.suptitle(f'Relationship Between Band Reflectance Values of {pixel_kind} Pixels\nViT, {experiment_name}, Low Coverage', fontsize=16)
    fig.tight_layout()
    fig.savefig(save_dir / file_name, format='png')
    # close explicitly so the three 10x10 inch figures do not accumulate in memory
    plt.close(fig)
    pairgrids.append(grid)

# keep the original per-figure names available to later cells
true_data_pairgrid, gen_data_pairgrid, non_cloud_pairgrid = pairgrids

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [13]:
# Setting up the dataset sampler the same as during training
# Items are visited in dataset order, one per batch.
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
common_kwargs = dict(pin_memory=False, drop_last=True)
test_kwargs = dict(batch_size=1, sampler=val_sampler, **common_kwargs)
test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
# number of batches the loader will yield
print(len(test_loader))

1621


In [14]:
# Adding pixel values to tensors iteratively
max_images = 200  # number of loader items pooled (matches the "200 Test Images" plot titles)

# Per-image (6, n_pixels) chunks are collected in lists and concatenated once
# after the loop; concatenating onto a growing tensor every iteration re-copies
# all previously collected pixels each time. Every list is seeded with an empty
# (6, 0) tensor so the final concatenation is valid even for an empty loader.
true_chunks = [torch.empty((6, 0)).to(local_rank)]
gen_chunks = [torch.empty((6, 0)).to(local_rank)]
non_cloud_chunks = [torch.empty((6, 0)).to(local_rank)]

model.eval()

# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    # `sample` (not `batch`) so the loop variable is not overwritten inside its own loop
    for i, sample in enumerate(test_loader):
        # get mask batches from dataset
        label_mask_batch = sample[:,1,:,:,:,:].to(local_rank)

        # get input image batches from dataset
        batch = sample[:,0,:,:,:,:].to(local_rank)

        loss, pred, mask = model(batch, label_mask_batch, 0.75)

        # un-normalize (x * std + mean), round up, and scale by 0.0001 to reflectance
        input = torch.ceil((batch.detach() * std) + mean) * 0.0001
        input_mask = label_mask_batch.detach()
        predicted = torch.ceil(model.unpatchify(pred).detach() * std + mean) * 0.0001
        input_masked = input * input_mask
        predicted_masked = predicted * input_mask
        non_cloud = input * (1-input_mask)

        # batch item 0, frame index 1: one row per band, one column per pixel
        non_cloud_chunks.append(non_cloud[0,:,1,:,:].reshape(6,-1))
        gen_chunks.append(predicted_masked[0,:,1,:,:].reshape(6,-1))
        true_chunks.append(input_masked[0,:,1,:,:].reshape(6,-1))

        if i + 1 == max_images:
            break

# Finalize after the loop (not inside the stop condition) so the numpy arrays
# are produced even when the loader yields fewer than `max_images` items.
true_pixels = torch.cat(true_chunks, dim=1).cpu().numpy()
gen_pixels = torch.cat(gen_chunks, dim=1).cpu().numpy()
non_cloud_pixels = torch.cat(non_cloud_chunks, dim=1).cpu().numpy()

# Pixels zeroed by the mask multiplication are marked NaN so plots skip them.
true_pixels[true_pixels == 0] = np.nan
gen_pixels[gen_pixels == 0] = np.nan
non_cloud_pixels[non_cloud_pixels == 0] = np.nan

In [15]:
# Putting these values in a dataframe
# Row r of each pooled pixel array becomes the column named band_names[r].

band_names = ['B2', 'B3', 'B4', 'B5', 'B7', 'B8']

non_cloud_data = pd.DataFrame(
    {band: non_cloud_pixels[row] for row, band in enumerate(band_names)}
)
gen_data = pd.DataFrame(
    {band: gen_pixels[row] for row, band in enumerate(band_names)}
)
true_data = pd.DataFrame(
    {band: true_pixels[row] for row, band in enumerate(band_names)}
)

In [None]:
# 21 edges -> 20 bins of width 0.05 covering the full [0, 1] axis range.
# range(20) stopped at 0.95, so pixels in (0.95, 1.0] were silently dropped
# from every histogram even though the axes extend to 1.
bin_edges = [round(i * 0.05, 2) for i in range(21)]

# One entry per figure: (data, colour, pixel-kind label for the title, output file name)
plot_specs = [
    (true_data, 'red', 'Ground Truth', 'band_correlations_ground_truth_all.png'),
    (gen_data, 'blue', 'Generated', 'band_correlations_generated_all.png'),
    (non_cloud_data, 'green', 'Non-Cloud', 'band_correlations_non_cloud_all.png'),
]

pairgrids = []
for frame, color, pixel_kind, file_name in plot_specs:
    grid = sns.PairGrid(frame, diag_sharey=False)
    grid.map_lower(sns.histplot, bins=bin_edges, color=color)
    grid.map_diag(sns.histplot, bins=bin_edges, color=color)
    grid.set(xlim=(0, 1), ylim=(0, 1))
    # Work on the grid's own figure rather than pyplot's implicit current figure
    fig = grid.fig
    fig.set_size_inches(10, 10)
    fig.suptitle(f'Relationship Between Band Reflectance Values of {pixel_kind} Pixels\nViT, {experiment_name}, 200 Test Images', fontsize=16)
    fig.tight_layout()
    fig.savefig(save_dir / file_name, format='png')
    # close explicitly so the three 10x10 inch figures do not accumulate in memory
    plt.close(fig)
    pairgrids.append(grid)

# keep the original per-figure names available to later cells
true_data_pairgrid, gen_data_pairgrid, non_cloud_pairgrid = pairgrids

  plt.tight_layout()
