In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import verde as vd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

import models
from mlnoddy.datasets import Norm, parse_geophysics
from test import reshape
from utils import to_pixel_samples


# Expose only GPU 0 to this process.
# NOTE(review): presumably only effective if set before CUDA is first initialised — confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})


In [None]:
# Forward / inverse value normalisation via Norm.min_max_clip and Norm.inverse_mmc.
# NORM_RANGE is presumably the clip range of raw field values; out_vals is the target range.
NORM_RANGE = (-10000, 10000)
norm = Norm(*NORM_RANGE, out_vals=(0, 1)).min_max_clip
unnorm = Norm(*NORM_RANGE, out_vals=(0, 1)).inverse_mmc


def subsample(raster, ls, ss=1):
    """Select points from a raster according to line spacing.

    Keeps every ``ls``-th column (the survey "lines") and every ``ss``-th
    sample along each line.

    Params:
        raster: torch tensor (CPU) or numpy array whose last two axes are
            (rows, cols), e.g. (1, H, W); singleton axes are squeezed out of
            ``vals``.
        ls: line spacing, in pixels (column stride).
        ss: sample spacing along each line, in pixels (row stride).
            Defaults to 1, i.e. every pixel along the line.

    Returns:
        x, y, vals: arrays of column indices, row indices and raster values
        at the selected points, shaped for gridding.
    """
    # input_cell_size = 20 # Noddyverse cell size is 20 m
    n_rows, n_cols = raster.shape[-2], raster.shape[-1]
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows), indexing="xy")
    x = x[::ss, ::ls]
    y = y[::ss, ::ls]
    # np.asarray accepts CPU torch tensors as well as numpy arrays.
    vals = np.asarray(raster)[..., ::ss, ::ls].squeeze()  # shape for gridding

    return x, y, vals


def grid(x, y, z, ls, cs_fac=4, d=180, method="cubic"):
    """Interpolate scattered (x, y, z) samples onto a regular grid.

    Uses verde's ``ScipyGridder`` (SciPy interpolation, cubic by default).
    This is not minimum-curvature gridding.

    Params:
        x, y: sample coordinates, in the same units as ``ls`` and ``d``
            (pixel indices when produced by ``subsample``).
        z: sample values matching ``x`` / ``y``.
        ls: line spacing; the output cell size is ``ls / cs_fac``.
        cs_fac: ratio of line spacing to output cell size (default 4).
        d: adjustable crop factor, but 180 is best for noddyverse. 200 Max.
            The output covers [0, d) along both axes.
        method: interpolation method passed to ``vd.ScipyGridder``.

    Returns:
        float32 array with a leading channel axis: (1, n_rows, n_cols).
    """
    w, e, s, n = np.array([0, d, 0, d], dtype=np.float32)
    cs = ls / cs_fac  # Cell size is e.g. 1/4 line spacing
    gridder = vd.ScipyGridder(method).fit(coordinates=(x, y), data=z)
    gridded = gridder.grid(
        data_names="forward",
        coordinates=np.meshgrid(
            np.arange(w, e, step=cs),
            np.arange(s, n, step=cs),
            indexing="xy",
        ),
    )
    values = gridded.get("forward").values.astype(np.float32)

    return np.expand_dims(values, 0)  # add channel dimension


In [None]:
# Checkpoint under evaluation; alternative runs are kept for quick switching.
model_path = "D:/luke/lte_geo/save/_train_swinir-lte_geo/230308-1738_joint_fly_346/joint_fly_346_epoch-last.pth"
# model_path = "D:/luke/lte_geo/save/_train_swinir-lte_geo/230517-0911_spotty_airport_6488/spotty_airport_6488_epoch-last.pth"
# model_path = "D:/luke/lte_geo/save/_train_swinir-lte_geo/230512-1658_extra_squid_2420/extra_squid_2420_epoch-last.pth"
# model_path = "D:/luke/lte_geo/save/_train_swinir-lte_geo/230516-1847_resident_crumble_3867/resident_crumble_3867_epoch-last.pth"
# model_path = "D:/luke/edsr-baseline-lte.pth"

# Noddy model file read when `which == 1` in the input-selection cell.
# fpath = "D:/luke/Noddy_data/Noddy_1M/DYKE_DYKE_DYKE/models_by_code/models/DYKE_DYKE_DYKE/20-09-04-16-04-59-857118037.g12"
fpath = "D:/luke/Noddy_data/Noddy_1M/DYKE_DYKE_FOLD/models_by_code/models/DYKE_DYKE_FOLD/20-09-04-18-53-46-989190706.g12"

# HR line spacing in pixels; the LR input is gridded at hr_ls * scale.
hr_ls = 4
scale = 4

# Rebuild the network from the checkpoint's "model" spec and move it to the GPU.
model_spec = torch.load(model_path)["model"]
model = models.make(model_spec, load_sd=True).cuda()


In [None]:
# Select the (gt, lr) pair to evaluate:
#   0 = synthetic stripe pattern, 1 = Noddy magnetic model (fpath), 2 = photo crop.
which = 1

if which == 0:
    # Synthetic Fourier test pattern: horizontal stripes (every other row = 1)
    # on all 3 channels of a (C, H, W) array.
    dim = 180
    im = np.zeros((3, dim, dim))
    s = dim // scale
    e = (dim - s) // 2
    # im[:, 128] = 1
    im[:, ::2, :] = 1
    gt = torch.from_numpy(im).to(torch.float32)
    # Crude stand-in for downsampling: a centre crop of about dim // scale
    # pixels. gt is (C, H, W), so only the two spatial axes are sliced.
    lr = gt[:, e:-e, e:-e]

elif which == 1:
    # Noddy Model: grid the magnetic raster at the HR line spacing (gt) and
    # at a line spacing `scale` times coarser (lr), then normalise both.
    mag = parse_geophysics(Path(fpath), mag=True)
    mag = torch.from_numpy(next(mag)).unsqueeze(0)

    x, y, vals = subsample(mag, ls=hr_ls)
    gt = torch.from_numpy(norm(grid(x, y, vals, ls=hr_ls))).to(torch.float32)
    x, y, vals = subsample(mag, ls=hr_ls * scale)
    lr = torch.from_numpy(norm(grid(x, y, vals, ls=hr_ls * scale))).to(torch.float32)

elif which == 2:
    scale = 2
    lr_path = "D:/luke/Flickr_-_paul_bica_-_vanishing_point.jpg"
    xx = 360
    yy = 660
    obs_size = 120

    from PIL import ImageDraw

    # Load the photo (closing the file handle), then take a 240 x 240 crop.
    with Image.open(lr_path) as photo:
        gt = transforms.ToTensor()(photo.convert('RGB'))
    gt = gt[:, 720-120:720+120, 1320-120:1320+120].contiguous()
    img_lr = transforms.Resize(gt.shape[1] // 2)(gt).contiguous()
    lr = (img_lr - 0.5) / 0.5  # map [0, 1] -> [-1, 1]

    # Display the LR input. ToPILImage expects floats in [0, 1], so undo the
    # [-1, 1] normalisation first.
    im = transforms.ToPILImage()(lr * 0.5 + 0.5)
    draw = ImageDraw.Draw(im)
    # NOTE(review): xx / yy / obs_size look like coordinates in the uncropped
    # photo and fall outside this crop — confirm intent.
    draw.rectangle([yy-obs_size//2, xx-obs_size//2, yy+obs_size//2, xx+obs_size//2], outline="red", width=3)
    display(im)

else:
    raise ValueError(f"`which` must be 0, 1 or 2, got {which!r}")


hr_coord, hr_val = to_pixel_samples(gt)

# Cell size of each query pixel: 2 / size along each axis (presumably the
# [-1, 1] coordinate frame used by to_pixel_samples — confirm).
hr_cell = torch.ones_like(hr_coord)
hr_cell[:, 0] *= 2 / gt.shape[-2]
hr_cell[:, 1] *= 2 / gt.shape[-1]

# Add a leading batch dimension.
gt = gt.unsqueeze(0)
lr = lr.unsqueeze(0)
hr_coord = hr_coord.unsqueeze(0)
hr_cell = hr_cell.unsqueeze(0)



In [None]:
# Sanity-check tensor shapes before running the model.
print(*(t.shape for t in (gt, lr, hr_coord, hr_cell)))


In [None]:
# evaluation: query the model at every HR pixel location.
inp = lr.to("cuda", non_blocking=True)

model.eval()
with torch.no_grad():
    # The input is flipped vertically due to digital image coordinate
    # conventions (https://blogs.mathworks.com/steve/2011/08/26/digital-image-processing-using-matlab-digital-image-representation/).
    sr = (
        model(
            inp.flip(-2),
            hr_coord.to("cuda", non_blocking=True),
            hr_cell.to("cuda", non_blocking=True),
        )
        .detach()
        .cpu()
    )
    # Frequency and coefficient maps cached on the model (presumably by the
    # forward pass above): `freqq` holds frequencies and `coeff` holds
    # coefficients. Flip them back to undo the input flip.
    freq = model.freqq.flip(-2)
    coef = model.coeff.flip(-2)

sr, batch = reshape(dict(inp=lr, coord=hr_coord, gt=gt), 0, 0, hr_coord, sr)

In [None]:
# Inspect how the frequency maps split into 2-channel (y, x) pairs.
freq_pairs = torch.split(freq, 2, dim=1)
freq_stacked = torch.stack(freq_pairs, dim=2)
print(freq.shape)
print(len(freq_pairs))
print(freq_pairs[0].shape)
print(freq_stacked.shape)

freq_y, freq_x = freq_stacked[0, 0, :, 0, 0], freq_stacked[0, 1, :, 0, 0]


In [None]:
# LR input, SR prediction and HR ground truth on one shared colour scale.
# NOTE(review): vmin comes from sr while vmax comes from gt — confirm this
# mixed range is intended.
vmin = sr.min()
vmax = gt.max()

fig, (axlr, axsr, axhr) = plt.subplots(1, 3, constrained_layout=True, figsize=(15, 5))
fig.suptitle(Path(model_path).stem.split("_epoch")[0])
shared_scale = {"vmin": vmin, "vmax": vmax}
axlr.imshow(lr.squeeze(), **shared_scale)
# sr was predicted from a vertically flipped input (inp.flip(-2)), hence origin="lower".
axsr.imshow(sr.squeeze(), origin="lower", **shared_scale)
axhr.imshow(gt.squeeze(), **shared_scale)
axlr.set_title("lr")
axsr.set_title("sr")
axhr.set_title("hr")
plt.show()


In [None]:
# Display Fourier Feature Space


def plot_F_features(freq, coef, title=None):
    """Plot LTE extracted Fourier Features.

    Scatter-plots one (x, y) frequency pair per point, taken at spatial
    location [0, :, 0, 0]. Each point is coloured by the sum of squares of
    the matching first-half / second-half ``coef`` channels.

    Params:
        freq: 4-D tensor (B, 2K, H, W); consecutive channel pairs are read
            as (y, x) components.
        coef: 4-D tensor with the same channel count as ``freq``; channel k
            is paired with channel k + K.
        title: figure title. Defaults to the checkpoint name derived from
            the notebook-level ``model_path``.

    Returns:
        The matplotlib Figure.
    """
    if title is None:
        title = Path(model_path).stem.split("_epoch")[0]

    # (B, 2K, H, W) -> (B, 2, K, H, W): axis 1 = component, axis 2 = feature.
    pairs = torch.stack(torch.split(freq, 2, dim=1), dim=2)
    freq_y = pairs[0, 0, :, 0, 0]
    freq_x = pairs[0, 1, :, 0, 0]
    half = freq.shape[1] // 2
    mag = (coef[0, :half, 0, 0] ** 2 + coef[0, half:, 0, 0] ** 2).cpu().numpy()

    fig = plt.figure(figsize=(6, 6), constrained_layout=True)
    plt.title(title)
    plt.scatter(
        freq_x.cpu().numpy(),
        freq_y.cpu().numpy(),
        c=mag,
        vmin=0,
        vmax=mag.max() / 4,  # saturate the top of the scale to show weaker features
        s=25,
        linewidths=0,
        cmap="bwr",
    )
    plt.xticks(np.linspace(-1.5, 1.5, 5))
    plt.yticks(np.linspace(-1.5, 1.5, 5))
    return fig


plot_F_features(freq, coef);


And now the DFT version, for comparison: the discrete Fourier transform of the low-resolution input.

In [None]:
import colorcet as cc
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.fft
import scipy.stats
import tifffile
from mpl_toolkits import mplot3d

# DFT input: the low-resolution model input, singleton axes squeezed out.
im = lr.numpy().squeeze()
arr = im
original_shape = im.shape
cell_size = 20  # m / pixel -- the rasters cover an area of fixed size

# plt.imshow(arr, cmap=cc.cm.CET_L1); plt.title("Input Array")


In [None]:
def _next_even_fast_len(n):
    """Return the smallest FFT-efficient length >= n that is also even."""
    length = scipy.fft.next_fast_len(n)
    while length % 2:  # It is useful later on that we get an even number
        length = scipy.fft.next_fast_len(length + 1)
    return length


fast_y = _next_even_fast_len(arr.shape[0])
fast_x = _next_even_fast_len(arr.shape[1])

print(f"Original shape (y, x): {original_shape}")
print(f"Next even efficient sizes (y, x): {(fast_y, fast_x)}")

# Zero-pad at the bottom / right up to the efficient sizes.
pad_y, pad_x = fast_y - arr.shape[0], fast_x - arr.shape[1]
pad_arr = np.pad(arr, pad_width=((0, pad_y), (0, pad_x)), constant_values=0)
print(f"New padded array shape (y, x): {pad_arr.shape}")

plt.imshow(pad_arr, cmap=cc.cm.CET_L1)
plt.title("Padded input image")
# Red lines mark where the original data ends.
plt.axhline(arr.shape[0], c="r")
plt.axvline(arr.shape[1], c="r")
plt.show()


In [None]:
# Transform the zero-padded array. Its even, FFT-efficient shape is the one
# the frequency axes are built from (pad_arr.shape).
F_arr = scipy.fft.fft2(pad_arr)

def make_F_plots(in_arr, ax_args=None, shift=True):
    """Plot magnitude, power and phase spectra of a 2-D DFT side by side.

    We will reuse this plot layout several times.

    Params:
        in_arr: complex DFT array.
        ax_args: Axes arguments shared by the imshow calls, as a single
            dictionary (e.g. ``{"extent": [left, right, bottom, top]}``).
            "cmap", "vmin" and "vmax" are always overridden to give the
            magnitude / power panels a common colour scale. The caller's dict
            is not modified.
        shift: True if ``in_arr`` is in an un-shifted FFT domain; panels 2-4
            are then ``fftshift``-ed to put DC at the centre.
    """
    # Spoiler alert - this is just to set a common colour map.
    ax_args = {
        **(ax_args or {}),
        "cmap": "bwr",
        "vmin": 0,
        "vmax": np.log(1 + np.abs(in_arr) ** 2).max(),
    }
    # Whether to shift the domain to DC at origin or do nothing
    fftshift = scipy.fft.fftshift if shift else (lambda a: a)
    extent = ax_args.get("extent")
    prefix = "Shifted " if shift else ""

    F_magnitude_spectrum = np.abs(in_arr)  # Fourier Magnitude Spectrum
    F_power_spectrum = np.abs(in_arr) ** 2  # Fourier Power Spectrum
    F_phase = np.angle(in_arr) / np.pi  # Fourier Phase Spectrum, in units of pi

    def _limit_to_extent():
        """Frame the current axes on the given extent, if any."""
        if extent:
            plt.ylim(extent[2], extent[3])
            plt.xlim(extent[0], extent[1])
            plt.axis("equal")

    plt.figure(figsize=(20, 5))

    plt.subplot(1, 4, 1)
    plt.title("Input Magnitude Spectrum")
    plt.imshow(np.log(1 + F_magnitude_spectrum), **ax_args)
    plt.colorbar(orientation="horizontal")
    _limit_to_extent()

    plt.subplot(1, 4, 2)
    plt.title(prefix + "Magnitude Spectrum")
    # Easier to interpret Fourier Magnitude Spectrum
    plt.imshow(np.log(1 + fftshift(F_magnitude_spectrum)), **ax_args)
    plt.colorbar(orientation="horizontal")
    _limit_to_extent()

    plt.subplot(1, 4, 3)
    plt.title(prefix + "Power Spectrum")
    plt.imshow(np.log(1 + fftshift(F_power_spectrum)), **ax_args)
    plt.colorbar(orientation="horizontal", label="Amplitude")
    _limit_to_extent()

    plt.subplot(1, 4, 4)
    plt.title(prefix + "Phase")
    # Use the same extent as the other panels. Without it the image is drawn
    # in pixel coordinates, and the extent-based limits below would show only
    # a sliver of it.
    plt.imshow(fftshift(F_phase), cmap=ax_args["cmap"], vmin=-1, vmax=1, extent=extent)
    plt.colorbar(orientation="horizontal", label=r"$\pi$")
    _limit_to_extent()

# DC-centred frequency axes (cycles per sample, d=1) of the padded array.
# Columns map to the extent's left/right and rows to its bottom/top.
freqs_y = scipy.fft.fftshift(scipy.fft.fftfreq(pad_arr.shape[0], 1))
freqs_x = scipy.fft.fftshift(scipy.fft.fftfreq(pad_arr.shape[1], 1))
freqs = freqs_y  # row-axis frequencies, under the name used previously
make_F_plots(F_arr, {"extent": [freqs_x[0], freqs_x[-1], freqs_y[0], freqs_y[-1]]})
