In [1]:
import os

# Custom CA bundle for `requests` (sentinelhub issues its HTTP calls through requests), e.g.
# when working behind a TLS-intercepting proxy. Only set it when the file exists: a
# REQUESTS_CA_BUNDLE pointing at a missing file makes every HTTPS request fail, so on other
# machines the default certificate bundle is kept.
CA_BUNDLE_PATH = 'L:\\repos\\worldstrat\\ca-certificates.crt'
if os.path.isfile(CA_BUNDLE_PATH):
    os.environ['REQUESTS_CA_BUNDLE'] = CA_BUNDLE_PATH

In [None]:
from sentinelhub import SHConfig, SentinelHubRequest, MimeType, CRS, BBox, DataCollection
from eolearn.io import SentinelHubInputTask, SentinelHubEvalscriptTask
from eolearn.core import EOTask, EOWorkflow, FeatureType, OutputTask, SaveTask, linearly_connect_tasks

from torchvision.transforms import Compose, Resize, InterpolationMode, Normalize, Lambda
from src.lightning_modules import LitModel
import torch
import torch.nn.functional as F

from sklearn.preprocessing import minmax_scale

from tqdm import tqdm

from src.datasources import (
    S2_ALL_12BANDS,
    SN7_SUBDIRECTORIES,
    S2_SN7_MEAN,
    S2_SN7_STD,
    SN7_BANDS_TO_READ,
    SN7_MAX_EXPECTED_HR_VALUE,
    SPOT_RGB_BANDS,
    JIF_S2_MEAN,
    JIF_S2_STD,
    S2_ALL_BANDS,
    SPOT_MAX_EXPECTED_VALUE_8_BIT,
    SPOT_MAX_EXPECTED_VALUE_12_BIT,
    ROOT_JIF_DATA_TRAIN,
    METADATA_PATH,
)

import matplotlib.pyplot as plt


import datetime as dt

import numpy as np

import getpass

# Area of interest (WGS84: min lon, min lat, max lon, max lat) and acquisition window.
bbox = BBox(bbox=[
        -3.693118031193279,
        40.403749979462816,
        -3.6664122696536765,
        40.42189244677638,
    ], crs=CRS.WGS84)
time_interval = ('2020-03-12', '2020-06-13')

# Sentinel Hub OAuth credentials. Never hard-code them in the notebook (credentials that were
# previously committed here should be revoked and rotated). Resolution order: the
# SH_CLIENT_ID / SH_CLIENT_SECRET environment variables, then whatever SHConfig() already
# loaded, then an interactive prompt. The config is passed explicitly to the task below,
# so it is not written back to disk with config.save().
config = SHConfig()
config.sh_client_id = (
    os.environ.get('SH_CLIENT_ID') or config.sh_client_id or getpass.getpass('Client Id')
)
config.sh_client_secret = (
    os.environ.get('SH_CLIENT_SECRET') or config.sh_client_secret or getpass.getpass('Client Secret')
)

resolution = 10  # metres per pixel
cache_folder = 'data'

# Download task: L2A bands plus the dataMask, CLM and SCL auxiliary layers.
# maxcc=1 means no cloud-cover filtering; acquisitions within 2 h are merged.
request = SentinelHubInputTask(
    data_collection=DataCollection.SENTINEL2_L2A,
    bands_feature=(FeatureType.DATA, "L2A_data"),
    additional_data=[(FeatureType.MASK, "dataMask"), (FeatureType.MASK, "CLM"), (FeatureType.MASK, "SCL")],
    resolution=resolution,
    maxcc=1,
    time_difference=dt.timedelta(hours=2),
    cache_folder=cache_folder,
    config=config,
)

response = request.execute(bbox=bbox, time_interval=time_interval)



In [None]:
# Band stack of the downloaded EOPatch; presumably indexed as data[time, row, col, band]
# (that is how the cells below slice it).
data = response.data['L2A_data']
data.shape

In [None]:
# Preview one acquisition as a colour composite of band indices 3, 2, 1
# (presumably Sentinel-2 B04/B03/B02, i.e. R/G/B).
# Acquisitions picked for the super-resolution stack below: [1, 3, 7, 9, 21, 22, 23, 27]
rgb_preview = data[27][..., [3, 2, 1]]
plt.imshow(rgb_preview)

In [None]:
# Preprocessing pipelines, keyed by input type: "lr" for the low-resolution revisits,
# "lrc" for the (categorical) scene classification layer.
transforms = {}
input_size = (160, 160)      # spatial size every LR revisit is resized to
output_size = (1054, 1054)   # NOTE: not used by these transforms; redefined before recomposition
interpolation = InterpolationMode.BICUBIC
normalize_lr = True
scene_classification_to_color = False
radiometry_depth = 12        # NOTE: not used by the transforms below

# Presumably 1-based band numbers -> 0-based indices into the per-band statistics.
lr_bands_to_use = np.array(S2_ALL_BANDS) - 1
normalize = Normalize(
    mean=JIF_S2_MEAN[lr_bands_to_use], std=JIF_S2_STD[lr_bands_to_use]
)
transforms["lr"] = Compose(
    [
        Lambda(lambda lr_revisit: torch.as_tensor(lr_revisit)),
        # Honour the flag: normalisation used to be applied even with normalize_lr=False.
        normalize if normalize_lr else Compose([]),
        Resize(size=input_size, interpolation=interpolation, antialias=True),
    ]
)

transforms["lrc"] = Compose(
    [
        Lambda(
            lambda lr_scene_classification: torch.as_tensor(lr_scene_classification)
        ),
        # Categorical: nearest-neighbour resizing keeps class labels intact.
        Resize(size=input_size, interpolation=InterpolationMode.NEAREST),
        # Categorical to RGB; NOTE: interferes with FilterData.
        # NOTE(review): SceneClassificationToColorTransform is not imported in this notebook;
        # setting scene_classification_to_color=True raises NameError until it is.
        SceneClassificationToColorTransform
        if scene_classification_to_color
        else Compose([]),
    ]
)

In [None]:
def load_model(checkpoint, device):
    """ Loads a model from a checkpoint, in evaluation mode, on ``device``.

    Parameters
    ----------
    checkpoint : str
        Path to the checkpoint.
    device : str or torch.device
        Device to place the model on. Checkpoint tensors are mapped directly to this
        device, so a checkpoint saved on a GPU can be loaded on a CPU-only machine.

    Returns
    -------
    model : lightning_modules.LitModel
        The model, switched to evaluation mode.
    """
    model = LitModel.load_from_checkpoint(checkpoint, map_location=device)
    return model.eval().to(device)

def bias_adjust(y_hat, y):
    """ Shift a prediction so its mean over the last two dimensions matches a reference.

    Parameters
    ----------
    y_hat : torch.Tensor
        The model output (super-resolved image).
    y : torch.Tensor
        Reference image (e.g. ground truth), broadcastable against ``y_hat``.

    Returns
    -------
    torch.Tensor
        ``y_hat`` plus the mean of ``y - y_hat`` taken over the last two
        (spatial) dimensions, with those dimensions kept for broadcasting.
    """
    residual = y - y_hat
    offset = residual.mean(dim=(-1, -2), keepdim=True)
    adjusted = y_hat + offset
    return adjusted

def infer_chip(input_chip, model):
    """Super-resolve one multi-revisit chip and return it as a channels-last image.

    Parameters
    ----------
    input_chip : numpy.ndarray
        Chip of shape (revisits, height, width, bands), as produced by ``extract_chips``.
    model : callable
        Super-resolution model. If it exposes ``parameters()`` (e.g. a torch module),
        the input is moved to the device of its first parameter.

    Returns
    -------
    numpy.ndarray
        Model output squeezed and transposed to (out_height, out_width, channels),
        bias-adjusted against a bilinear upsampling of bands [3, 2, 1] of the first revisit.
    """
    # (revisits, H, W, bands) -> (revisits, bands, H, W): channels-first for the lr transform.
    input_tensor = torch.as_tensor(np.transpose(input_chip, (0, 3, 1, 2)), dtype=torch.float32)

    # Normalise/resize every revisit, then add the batch dimension:
    # (1, revisits, bands, H', W') for any number of revisits and bands.
    transformed_input = transforms["lr"](input_tensor).unsqueeze(0)

    # Keep input and model on the same device.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        transformed_input = transformed_input.to(first_param.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y = model(transformed_input)

        # Reference for the bias adjustment: first revisit, bands [3, 2, 1], bilinearly
        # upsampled to the model's output size.
        reference = F.interpolate(
            transformed_input[:, 0, [3, 2, 1], :, :],
            size=tuple(y.shape[-2:]),
            mode='bilinear',
            align_corners=False,
        )
        y = bias_adjust(y, reference)

    # Move to CPU for NumPy, drop singleton dims and go channels-last.
    return np.squeeze(y.cpu().numpy()).transpose(1, 2, 0)

In [None]:
def extract_chips(image, chip_size, fill_value=np.nan):
    """
    Extract non-overlapping chips of size chip_size from the input image.

    Chips are taken on a regular grid starting at (0, 0), in row-major order. Chips on the
    bottom/right border that extend past the image are padded with ``fill_value`` (NaN by
    default), so every chip has exactly ``chip_size`` spatially.

    Args:
        image (numpy.ndarray): Input image of shape [bands, height, width, channels]
        chip_size (tuple): Size of each chip (chip_height, chip_width)
        fill_value (float): Value for the padded area of border chips. Chips are always
            float64 arrays, whatever the input dtype.
    Returns:
        list: List of chips, each of shape [bands, chip_height, chip_width, channels]
        list: List of positions (i, j) of each chip's top-left pixel in ``image``
    """
    bands, height, width, channels = image.shape
    chip_height, chip_width = chip_size
    chips = []
    positions = []

    for i in range(0, height, chip_height):
        for j in range(0, width, chip_width):
            # Explicit float dtype keeps the float64 output and prevents an integer
            # fill_value from truncating float image data on assignment.
            chip = np.full((bands, chip_height, chip_width, channels), fill_value, dtype=float)
            # Slicing past the border is clipped by NumPy, so `tile` may be smaller than a chip.
            tile = image[:, i:i + chip_height, j:j + chip_width, :]
            chip[:, :tile.shape[1], :tile.shape[2], :] = tile
            chips.append(chip)
            positions.append((i, j))

    return chips, positions

def recompose_image(chips, positions, image_shape, chip_size, output_size):
    """
    Recompose the super-resolved chips back into the full image.

    Each chip is placed on the output grid at ``(i // chip_height, j // chip_width)``, derived
    from its position in the original image. The grid covers only the complete chips of the
    original image, i.e. ``(height // chip_height) x (width // chip_width)`` cells. Chips that
    start in a partial border strip (e.g. padded edge chips from ``extract_chips``) have no
    cell and are skipped; previously they raised a broadcasting error.

    Args:
        chips (list): List of super-resolved chips, each array-like of shape
            (out_height, out_width, channels); smaller chips are zero-padded, larger ones cropped
        positions (list): List of positions of chips in the format (i, j), in original pixels
        image_shape (tuple): Original shape of the image (bands, height, width, channels)
        chip_size (tuple): Size of each input chip (chip_height, chip_width)
        output_size (tuple): Size of each super-resolved chip (output_height, output_width)
    Returns:
        numpy.ndarray: Recomposed image of shape (height // chip_height * output_height,
            width // chip_width * output_width, channels); pixels covered by no chip are NaN
    Raises:
        ValueError: if ``chips`` is empty (the channel count cannot be inferred)
    """
    if len(chips) == 0:
        raise ValueError("chips must contain at least one chip")

    _, original_height, original_width, _ = image_shape
    chip_height, chip_width = chip_size
    output_height, output_width = output_size

    grid_rows = original_height // chip_height
    grid_cols = original_width // chip_width
    n_channels = np.shape(chips[0])[2]
    recomposed_image = np.zeros((grid_rows * output_height, grid_cols * output_width, n_channels))
    count_matrix = np.zeros_like(recomposed_image)

    for chip, (i, j) in zip(chips, positions):
        row, col = i // chip_height, j // chip_width
        if row >= grid_rows or col >= grid_cols:
            continue  # partial border chip: outside the complete-chip grid
        # Accept nested lists as well as arrays; crop anything larger than output_size.
        chip = np.asarray(chip)[:output_height, :output_width]
        pad_height = output_height - chip.shape[0]
        pad_width = output_width - chip.shape[1]
        if pad_height or pad_width:
            # Pad the chip to the required size if it's smaller
            chip = np.pad(chip, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant')
        top, left = row * output_height, col * output_width
        recomposed_image[top:top + output_height, left:left + output_width, :] += chip
        count_matrix[top:top + output_height, left:left + output_width, :] += 1

    # Average contributions; uncovered pixels become NaN (as 0/0 did) without a RuntimeWarning.
    return np.divide(recomposed_image, count_matrix,
                     out=np.full_like(recomposed_image, np.nan), where=count_matrix > 0)

In [None]:
def save_image(nparray, path):
    from PIL import Image
    
    print(np.shape(nparray))
    formatted = (nparray * 255 / np.max(nparray)).astype('uint8')
    im = Image.fromarray(formatted)
    im.save(path)

In [None]:
# Inference device. DEVICE_OVERRIDE forces a device ('cpu' or 'cuda');
# set it to None to use CUDA automatically whenever it is available.
DEVICE_OVERRIDE = 'cpu'
if DEVICE_OVERRIDE is None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = DEVICE_OVERRIDE

model = load_model('pretrained_model/model.ckpt', device)

In [None]:
# Acquisitions (time indices into `data`) stacked as the low-resolution revisits.
REVISIT_INDICES = [1, 3, 7, 9, 21, 22, 23, 27]
input_data = data[REVISIT_INDICES, :, :, :]
chip_size = (20, 20)
output_size = (156, 156)

# Split the revisit stack into chips; border chips are padded by extract_chips.
chips, positions = extract_chips(input_data, chip_size)

# Super-resolve each chip, then rescale each band of the chip to [0, 0.5].
transformed_chips = []
for chip in tqdm(chips):
    super_resolved_chip = infer_chip(chip, model)  # (out_height, out_width, channels)
    n_channels = super_resolved_chip.shape[-1]
    # Scale every band over all pixels of the chip. Iterating `for band in chip` walked
    # image rows, so each row was scaled on its own and produced horizontal streaks.
    scaled_chip = minmax_scale(
        super_resolved_chip.reshape(-1, n_channels), feature_range=(0, 0.5)
    ).reshape(super_resolved_chip.shape)
    transformed_chips.append(scaled_chip)

In [None]:
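# Recompose the super-resolved chips back into the full image.
# Helper call kept for reference (it needs `positions` and the LR `chip_size`); the next cell recomposes inline.
#super_resolved_image = recompose_image(transformed_chips, positions, np.shape(input_data), chip_size, output_size)

#print(np.shape(super_resolved_image))  # Debug print to check the shape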

In [None]:
# Recompose inline: place each super-resolved chip using its position in the LR image.
# The output grid covers only complete chips (floor division); border chips that extend
# past the LR image are skipped. Chips used to be laid out sequentially, which misplaced
# every chip after the first partial one whenever the width was not a multiple of the
# chip width (extraction produces ceil(width / chip_width) chips per row).
_, original_height, original_width, original_channels = np.shape(input_data)
chip_height, chip_width = chip_size
output_height, output_width = output_size
grid_rows = original_height // chip_height
grid_cols = original_width // chip_width
recomposed_height = grid_rows * output_height
recomposed_width = grid_cols * output_width
n_channels = np.shape(transformed_chips[0])[2]
recomposed_image = np.zeros((recomposed_height, recomposed_width, n_channels))
count_matrix = np.zeros((recomposed_height, recomposed_width, n_channels))

for chip, (i, j) in zip(transformed_chips, positions):
    row, col = i // chip_height, j // chip_width
    if row >= grid_rows or col >= grid_cols:
        continue  # partial border chip: no cell in the complete-chip grid
    top, left = row * output_height, col * output_width
    recomposed_image[top:top + output_height, left:left + output_width, :] += chip
    count_matrix[top:top + output_height, left:left + output_width, :] += 1

recomposed_image /= count_matrix
super_resolved_image = recomposed_image

In [None]:
# Shape of the recomposed super-resolved image: (height, width, channels).
super_resolved_image.shape

In [None]:
import cv2

# Export results as JPEGs. The arrays hold bands [3, 2, 1] (presumably R, G, B) while OpenCV
# writes BGR, so channels are converted before writing (red and blue used to be swapped).
os.makedirs("output", exist_ok=True)


def _write_rgb_jpeg(path, rgb):
    """Scale an (H, W, 3) RGB array so its max maps to 255 and write it with OpenCV.

    NaNs become 0 and values are clipped to [0, 255]. Raises OSError if OpenCV cannot
    write the file (cv2.imwrite only returns False, e.g. when the directory is missing).
    """
    values = np.nan_to_num(np.asarray(rgb, dtype=float), nan=0.0)
    peak = values.max()
    scaled = values * 255 / peak if peak > 0 else np.zeros_like(values)
    formatted = np.clip(scaled, 0, 255).astype('uint8')
    print(np.shape(formatted))
    if not cv2.imwrite(path, cv2.cvtColor(formatted, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite could not write {path}")


_write_rgb_jpeg("output/super_resolved_image.jpeg", super_resolved_image)

# Per-band rescale over the whole image. The old list comprehension scaled each row
# separately, and `list * 255` repeated that list 255 times instead of scaling it.
n_channels = super_resolved_image.shape[-1]
scaled_out = minmax_scale(
    super_resolved_image.reshape(-1, n_channels), feature_range=(0, 0.5)
).reshape(super_resolved_image.shape)
_write_rgb_jpeg("output/scaled_super_resolved_image.jpeg", scaled_out)

# First revisit of the LR input, bands [3, 2, 1].
indata = input_data[0, :, :, :][:, :, [3, 2, 1]]
_write_rgb_jpeg("output/original.jpeg", indata)

In [None]:
# Side-by-side comparison: LR input (first revisit) vs. recomposed super-resolved image.
def _scale_bands_for_display(image, feature_range=(0, 0.5)):
    """Min-max scale each band (last axis) of an (H, W, C) image over all of its pixels.

    Iterating `for band in image` would walk image rows instead, scaling each row on its own.
    """
    image = np.asarray(image, dtype=float)
    flat = image.reshape(-1, image.shape[-1])
    return minmax_scale(flat, feature_range=feature_range).reshape(image.shape)


y_res = _scale_bands_for_display(np.squeeze(super_resolved_image))
x_res = _scale_bands_for_display(indata)

# A modest figure size; (150, 150) inches produced an enormous canvas.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(x_res)
ax[0].set_title("Input (first revisit)")
ax[0].axis("off")
ax[1].imshow(y_res)
ax[1].set_title("Super-resolved")
ax[1].axis("off")
plt.show()