In [None]:
# Reload imported project modules automatically before each cell runs, so edits to
# the helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import os
import pathlib
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torcheval.metrics.functional import peak_signal_noise_ratio
from torchmetrics.functional.image import structural_similarity_index_measure

from super_resolution.src.sen2venus_dataset import (
    S2VSite,
    S2VSites,
    create_train_test_split,
)
from super_resolution.src.visualization import plot_gallery

## Constants

In [None]:
# Root folder holding the dataset and the cached results. It defaults to the
# original machine-specific location; set the STAT3007_DATA_DIR environment
# variable to run the notebook elsewhere without editing this cell.
DATA_DIR = pathlib.Path(
    os.environ.get("STAT3007_DATA_DIR", "C:/Users/Mitch/stat3007_data")
)
SITES_DIR = DATA_DIR / "sites"
BICUBIC_DIR = DATA_DIR / "bicubic_results"

# Sample count used when extrapolating the timing results to all samples.
TOTAL_SAMPLES = 132_955

## Test standard metrics on bicubic interpolation

In [None]:
# Load the FGMANAUS site with bands="rgbnir" on the CPU.
# The directory is passed as a string with a trailing separator;
# os.path.join(..., "") appends the separator of the current OS instead of a
# hard-coded Windows backslash, and SITES_DIR replaces the re-built "sites" path.
site = S2VSite(
    site_name="FGMANAUS",
    bands="rgbnir",
    download_dir=os.path.join(SITES_DIR, ""),
    device="cpu",
)
print(f"{len(site)} patches")

In [None]:
def scale_images(images: torch.Tensor) -> torch.Tensor:
    """Min-max scale a tensor to the range [0, 1].

    The minimum and maximum are taken over the whole tensor (not per image or
    per channel), so relative intensities between elements are preserved.

    Args:
        images: Tensor of any shape with at least one element.

    Returns:
        A floating-point tensor of the same shape. When every element of
        ``images`` has the same value the result is all zeros instead of the
        NaNs a 0 / 0 division would produce.
    """
    min_val = images.min()
    value_range = images.max() - min_val
    shifted = images - min_val

    if value_range == 0:
        # Constant input: dividing by a zero range would yield NaN everywhere.
        # Use the dtype a true division would have produced.
        return torch.zeros_like(shifted, dtype=torch.result_type(shifted, 1.0))

    return shifted / value_range


def preprocess_images(images: torch.Tensor, scale_output: bool = True) -> torch.Tensor:
    """Keep the first three entries along dimension 1 of a 4-D batch.

    Args:
        images: 4-D tensor; only indices 0-2 of dimension 1 are kept.
        scale_output: When True, the sliced tensor is passed through
            ``scale_images`` before being returned.

    Returns:
        The sliced (and optionally min-max scaled) tensor.
    """
    first_three = images[:, :3, :, :]
    return scale_images(first_three) if scale_output else first_three

In [None]:
# Stack the first image of every (sentinel, venus) pair in the site into a single
# batch, keep the first three channels and leave the values unscaled.
X = preprocess_images(
    torch.stack([sentinel_image for sentinel_image, _ in site]),
    scale_output=False,
)

In [None]:
# Stack the second image of every (sentinel, venus) pair in the site into a single
# batch, keep the first three channels and leave the values unscaled.
Y_target = preprocess_images(
    torch.stack([venus_image for _, venus_image in site]),
    scale_output=False,
)

In [None]:
# Upsample the Sentinel batch with bicubic interpolation to the spatial size of the
# Venus targets. The size is read from Y_target instead of being hard-coded so the
# two tensors always line up for the metric computation.
interpolated_X = interpolate(X, size=tuple(Y_target.shape[-2:]), mode="bicubic")

In [None]:
# Score the bicubic upsampling against the Venus targets over the whole batch.
psnr = peak_signal_noise_ratio(input=interpolated_X, target=Y_target)
ssim = structural_similarity_index_measure(preds=interpolated_X, target=Y_target)
print(f"Metrics\nPSNR: {psnr}\nSSIM: {ssim}")

In [None]:
# Display example
spacing = 23


def space(num_space: int) -> str:
    """Return a string of `num_space` blanks, used to align the column titles."""
    return num_space * " "


print(
    f"{space(spacing)}Sentinel{space(2*spacing)}Bicubic{space(2*spacing)}Venus{space(spacing)}"
)
for i in range(1):
    # Input, bicubic result and target for sample i; permute(1, 2, 0) moves the
    # first axis last and each image is min-max scaled on its own for display.
    gallery = [
        scale_images(batch[i].permute(1, 2, 0))
        for batch in (X, interpolated_X, Y_target)
    ]
    plot_gallery(gallery, xscale=10, yscale=10)

## Testing Speed of Bicubic Interpolation

In [None]:
# Batch size passed to the DataLoaders in the timing run.
BATCH_SIZE = 512
# Number of iterations of the timing loop; iteration i uses the first i + 1 sites.
NUM_SITES_TO_TRY = 1

In [None]:
# S2VSites.SITES yields pairs; only the first element (the site name) is needed.
site_names = [name for name, _ in S2VSites.SITES]

In [None]:
# Run bicubic interpolation on a cumulative number of sites, recording after every
# batch how many samples have been processed and how much time has elapsed.
# NOTE: the clock starts before the datasets are built, so the recorded times
# include dataset/dataloader setup and batch loading, not only the interpolation.
# NOTE: X, Y and interpolated_X are rebound here and replace the tensors of the
# same name created in the metrics section.
cum_num_samples = []
cum_time = []
for i in range(NUM_SITES_TO_TRY):
    num_samples = 0
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()

    # Prepare data (os.path.join(..., "") adds the OS-specific trailing separator).
    train_data, test_data = create_train_test_split(
        os.path.join(SITES_DIR, ""), seed=42, sites=set(site_names[: i + 1])
    )
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

    # Start interpolating: the train and test splits are processed identically,
    # train first, then test.
    for dataloader in (train_dataloader, test_dataloader):
        for X, Y in dataloader:
            interpolated_X = interpolate(X, size=(256, 256), mode="bicubic")
            num_samples += X.shape[0]
            cum_num_samples.append(num_samples)
            cum_time.append(time.perf_counter() - start_time)

In [None]:
# Common prefix of the two pickle files that hold the timing results.
FILE_NAME = "samples2500_batchsize512"

In [None]:
# File names for the cumulative sample counts and the cumulative times.
samples_file = f"{FILE_NAME}_cumsamples.pkl"
time_file = f"{FILE_NAME}_cumtimes.pkl"

Save the results of the long-running process as pickle (`.pkl`) files.

In [None]:
# Optional manual save of the timing results: uncomment to write both lists to
# BICUBIC_DIR (the directory must already exist).
# with open(BICUBIC_DIR / samples_file, "wb") as file:
#     pickle.dump(cum_num_samples, file)

# with open(BICUBIC_DIR / time_file, "wb") as file:
#     pickle.dump(cum_time, file)

In [None]:
# Reuse cached timing results when both files exist; otherwise persist the results
# currently in memory, so a fresh "Restart & Run All" no longer fails on missing
# files.
# NOTE: pickle.load can execute arbitrary code - only load files you created.
samples_path = BICUBIC_DIR / samples_file
time_path = BICUBIC_DIR / time_file

if samples_path.exists() and time_path.exists():
    with open(samples_path, "rb") as file:
        cum_num_samples = pickle.load(file)

    with open(time_path, "rb") as file:
        cum_time = pickle.load(file)
else:
    # No complete cache yet: write both lists so they stay consistent.
    BICUBIC_DIR.mkdir(parents=True, exist_ok=True)
    with open(samples_path, "wb") as file:
        pickle.dump(cum_num_samples, file)

    with open(time_path, "wb") as file:
        pickle.dump(cum_time, file)

Plot the results

In [None]:
# Measured cumulative run time against the number of samples processed.
fig, ax = plt.subplots()
ax.plot(cum_num_samples, cum_time)
ax.set_xlabel("Number of samples")
ax.set_ylabel("Time to complete bicubic interpolation (seconds)")
ax.set_title("Speed performance of bicubic interpolation on CPU")
plt.show()

Extrapolate the running time to all samples with a linear fit.

In [None]:
# Fit a straight line (time as a function of sample count) to the measured points.
# Both arrays are reshaped to (n, 1): the features must be 2-D, and a 2-D target
# keeps the predictions 2-D as well.
lr = LinearRegression()
lr.fit(
    np.expand_dims(np.array(cum_num_samples), 1), np.expand_dims(np.array(cum_time), 1)
)
# Evaluate the fit on an evenly spaced grid from 0 to the full sample count, using
# the TOTAL_SAMPLES constant rather than repeating the literal 132_955.
num_samples = np.expand_dims(np.linspace(0, TOTAL_SAMPLES), 1)
predicted_times = lr.predict(num_samples)

In [None]:
# Fitted line and measured points on the same axes, with seconds converted to hours.
fig, ax = plt.subplots()
ax.plot(num_samples, predicted_times / 60**2, label="Predicted times")
ax.plot(cum_num_samples, np.array(cum_time) / 60**2, label="True times")
ax.set_xlabel("Number of samples")
ax.set_ylabel("Time to complete bicubic interpolation (hours)")
ax.set_title("Speed performance of bicubic interpolation on CPU")
ax.legend()
plt.show()

In [None]:
# Evaluate the fitted line at TOTAL_SAMPLES and report the estimate in hours.
total_samples_query = np.array([[TOTAL_SAMPLES]])
predicted_time = lr.predict(total_samples_query)
predicted_hours = predicted_time[0, 0] / 60 / 60
print(
    f"Predicted {predicted_hours:.2f}hrs for running bicubic interpolation on all samples"
)