# Load Raw Images

## Compare HrHSI and Snapshot images

### Load raw HrHSI

In [3]:
from preprocessing import preprocessHrHSI

# Calibration files for the FX10 (HrHSI) camera (presumably camera matrix + distortion coefficients)
hr_mtx_path = 'data/raw/calibration/hrHSI_matrix.npy'
hr_dist_path = 'data/raw/calibration/hrHSI_dist.npy'

# Raw capture to load; the ".hdf5" extension is appended when loading
hr_path = "data/raw/FX10/corn_m3"

# preprocessHrHSI returns an indexable result: element 0 is used as the image,
# element 1 as its wavelengths
hr_file = preprocessHrHSI(hr_path + ".hdf5", hr_mtx_path, hr_dist_path)
hr_img, hr_wavelengths = hr_file[0], hr_file[1]

In [None]:
# Preview the raw HrHSI cube at a chosen pixel and spectral index
from preprocessing import previewHrHSI

previewHrHSI(
    hr_img,
    hr_wavelengths,
    selected_pixel=(500, 500),
    selected_spectrum=200,
)

### Load raw Snapshot

In [5]:
from preprocessing import preprocessSnapshot

# Calibration files for the snapshot camera (presumably camera matrix + distortion coefficients)
ss_mtx_path = 'data/raw/calibration/snapshot_matrix.npy'
ss_dist_path = 'data/raw/calibration/snapshot_dist.npy'

# Directory holding the snapshot frames for this scene, plus the (start, end)
# offset ranges handed to preprocessing
ss_path = "data/raw/Snapshot/processed/train/corn_m3/"
ss_x_off = (0, 405)
ss_y_off = (0, 214)

# preprocessSnapshot returns an indexable result: element 0 is used as the image,
# element 1 as its wavelengths
ss_file = preprocessSnapshot(ss_path, ss_mtx_path, ss_dist_path, ss_x_off, ss_y_off)
ss_img, ss_wavelengths = ss_file[0], ss_file[1]

In [None]:
# Preview the raw snapshot cube at a chosen pixel and spectral index, then show its shape
from preprocessing import previewSnapshot

previewSnapshot(
    ss_img,
    ss_wavelengths,
    selected_pixel=(115, 150),
    selected_spectrum=12,
)

ss_img.shape

### Compare images

In [None]:
from align_images import align_images, plot_image_comparison

# HrHSI (start, end) offset ranges and geometric correction.
# Alternatives tried: full frame x=(0, 1084), y=(0, 1015); previous x=(104, 1019), y=(189, 720).
hr_x_off = (106, 1018)
hr_y_off = (189, 720)
rot = -0.2
shear = 0.017
# To disable the correction use: rot = 0, shear = 0

# Snapshot (start, end) offset ranges
ss_x_off = (3, 402)
ss_y_off = (2, 212)

# Align the two cubes using the offsets / rotation / shear above
hr_img_al, ss_img_al = align_images(
    hr_img=hr_img,
    ss_img=ss_img,
    ss_x_off=ss_x_off,
    ss_y_off=ss_y_off,
    hr_x_off=hr_x_off,
    hr_y_off=hr_y_off,
    rot=rot,
    shear=shear,
)

# Side-by-side comparison of the aligned cubes
plot_image_comparison(
    hr_img=hr_img_al,
    hr_wavelengths=hr_wavelengths,
    ss_img=ss_img_al,
    ss_wavelengths=ss_wavelengths,
    selected_pixel=(135, 235),
    selected_spectrum=715,
)

#### Compare aspect ratios

In [None]:
# Aspect ratio = columns / rows (shape[1] / shape[0]) of each aligned cube;
# computed once and reused for the printout and the difference.
hr_aspect = hr_img_al.shape[1] / hr_img_al.shape[0]
ss_aspect = ss_img_al.shape[1] / ss_img_al.shape[0]

# Compare resolutions
print("HrHSI:", hr_img_al.shape)
print("Snapshot:", ss_img_al.shape)
print("\n")

# Compare aspect ratios
print("HrHSI:", round(hr_aspect, 6))
print("Snapshot:", round(ss_aspect, 6))

# Positive => the HrHSI crop is relatively wider than the snapshot crop
print("Difference:", round(hr_aspect - ss_aspect, 6))

# Full Preprocessing

## Load hrHSI and go through full preprocessing pipeline

In [None]:
from preprocessing import preprocessFullHSI
import matplotlib.pyplot as plt

# Raw FX10 (HrHSI) capture and its calibration files
hr_path = 'data/raw/FX10/pumpkin_s2.hdf5'
hr_mtx_path = 'data/raw/calibration/hrHSI_matrix.npy'
hr_dist_path = 'data/raw/calibration/hrHSI_dist.npy'

# Target snapshot geometry (square, 24 bands) and its band wavelengths:
# 667 to 943 in steps of 12 -> [667, 679, 691, ..., 931, 943] (24 values)
ss_shape = (210, 210, 24)
ss_wavelengths = list(range(667, 944, 12))

# Square offset window and geometric correction for the HrHSI cube
hr_x_off = (110, 590)
hr_y_off = (191, 719)
rot = -0.3
shear = 0.02

# Run the full pipeline; element 0 of the result is used as the image, element 1 as its wavelengths
hr_file = preprocessFullHSI(
    path_to_hdf5=hr_path,
    mtx_path=hr_mtx_path,
    dist_path=hr_dist_path,
    hr_x_off=hr_x_off,
    hr_y_off=hr_y_off,
    rot=rot,
    shear=shear,
    ss_shape=ss_shape,
    ss_wavelengths=ss_wavelengths,
)
hr_img, hr_wavelengths = hr_file[0], hr_file[1]

# Quick look at band 3 and the resulting cube shape
plt.imshow(hr_img[:, :, 3])
plt.show()

print(hr_img.shape)

## Load snapshot and preprocess

In [None]:
from preprocessing import preprocessSnapshot
import matplotlib.pyplot as plt

# Snapshot frames for the same scene (pumpkin_s2) and the snapshot-camera calibration files
ss_path = "data/raw/Snapshot/processed/train/pumpkin_s2/"
ss_mtx_path = 'data/raw/calibration/snapshot_matrix.npy'
ss_dist_path = 'data/raw/calibration/snapshot_dist.npy'

# (start, end) offset ranges; both span 210 pixels, matching the square target shape
ss_x_off = (3, 213)
ss_y_off = (2, 212)

# Load and preprocess; element 0 of the result is used as the image, element 1 as its wavelengths
ss_file = preprocessSnapshot(ss_path, ss_mtx_path, ss_dist_path, ss_x_off, ss_y_off)
ss_img, ss_wavelengths = ss_file[0], ss_file[1]

# Quick look at band 3 and the resulting cube shape
plt.imshow(ss_img[:, :, 3])
plt.show()

print(ss_img.shape)

## Show side by side comparison

In [None]:
from align_images import plot_image_comparison

# Side-by-side comparison of the fully preprocessed HrHSI and snapshot cubes
plot_image_comparison(
    hr_img=hr_img,
    hr_wavelengths=hr_wavelengths,
    ss_img=ss_img,
    ss_wavelengths=ss_wavelengths,
    selected_pixel=(100, 110),
    selected_spectrum=800,
)

# Compare fully-processed files

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Grain names used as image-ID prefixes (reference list for choosing image_id; not used below)
grains = ["spelt", "rye", "flax", "wheatgrass", "pumpkin", "sunflower",
          "flaxb", "buckwheat", "millet", "barley", "mix", "corn"]

# Select image and the band to preview
image_id = "pumpkin_s2"
preview_channel = 12

# train_img is loaded from the lr/ folder and val_img from the hr/ folder of the same image
train_img = np.load("data/processed/full_hsi/lr/" + image_id + ".npy")
val_img = np.load("data/processed/full_hsi/hr/" + image_id + ".npy")

# Label each cube once so the plot titles and the printouts always agree
labelled_cubes = (("LrHSI", train_img), ("HrHSI", val_img))

# Side-by-side grayscale preview
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
for ax, (label, cube) in zip(axs, labelled_cubes):
    ax.axis('off')
    ax.imshow(cube[:, :, preview_channel], cmap="gray")
    ax.set_title(label)
plt.show()

# Print aspect ratios (columns / rows)
for label, cube in labelled_cubes:
    print(f"{label}:", round(cube.shape[1] / cube.shape[0], 6))

# Print shapes
for label, cube in labelled_cubes:
    print(f"{label}:", cube.shape)

# .npy Image Viewer

In [None]:
import numpy as np
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

# HrHSI band wavelengths (nm) used as the x-axis of the 24-band plots in this cell
hr_wavelengths = [
    666, 679, 690, 704, 715, 726, 739, 750,
    764, 775, 786, 800, 811, 822, 835, 846,
    860, 871, 882, 896, 907, 918, 932, 943,
]

# Define functions
def view_two_npy_images(channel=0):
    """Show the HR cube, the SR cube and their difference for one channel.

    Reads the notebook globals ``hr_img``, ``sr_img``, ``difference_image``,
    ``img_id`` and ``model``; only the channel index is a parameter.

    Args:
        channel: index into the last axis of each cube.
    """
    fig, (ax_hr, ax_sr, ax_diff) = plt.subplots(1, 3, figsize=(16, 6))
    fig.suptitle(f"img: {img_id}, \n model: {model}", fontsize=16)

    # HR and SR share a fixed [0, 1] colour scale so they are directly comparable
    panels = (
        (ax_hr, hr_img, f'HR Image - Channel: {channel}'),
        (ax_sr, sr_img, f'SR Image - Channel: {channel}'),
    )
    for ax, cube, title in panels:
        ax.imshow(cube[:, :, channel], vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis('off')

    # Signed error on a diverging colormap, symmetric around zero
    diff_artist = ax_diff.imshow(difference_image[:, :, channel], cmap='bwr', vmin=-1, vmax=1)
    ax_diff.set_title(f'Difference Image - Channel: {channel}')
    ax_diff.axis('off')
    fig.colorbar(diff_artist, ax=ax_diff, orientation='vertical', fraction=0.046, pad=0.04)



def plot_values_per_wavelengths(wavelengths):
    """Plot the spatially averaged spectrum of the HR and SR cubes.

    Averages the notebook globals ``hr_img`` and ``sr_img`` over axes 0 and 1
    (one mean per channel) and draws both curves against ``wavelengths`` on a
    fixed [0, 1] y-range.

    Args:
        wavelengths: x-axis values, one per channel.
    """
    # Channel-wise means keyed by legend label; dict order is the plotting order
    spectra = {
        'HR Image': hr_img.mean(axis=(0, 1)),
        'SR Image': sr_img.mean(axis=(0, 1)),
    }

    plt.figure(figsize=(5, 5))
    for label, spectrum in spectra.items():
        plt.plot(wavelengths, spectrum, label=label)
    plt.xlabel('Wavelength (nm)')
    plt.ylabel('Mean Value')
    plt.ylim(0, 1)
    plt.title('Mean Values per Wavelength')
    plt.legend()




# Select image and model
img_id = "corn_l4"
model = "HSI_x2_real"

# Load the HR reference and the model's SR output for the same image
hr_path = "results/c_hr/" + img_id + ".npy"
sr_path = "results/models/" + model + "/sr_synth/" + img_id + ".npy"
hr_img = np.load(hr_path)
sr_img = np.load(sr_path)

# Fail early with a clear message instead of a broadcasting error / silent broadcast
if hr_img.shape != sr_img.shape:
    raise ValueError(f"HR shape {hr_img.shape} != SR shape {sr_img.shape} for {img_id} / {model}")
if len(hr_wavelengths) != hr_img.shape[2]:
    raise ValueError(f"{len(hr_wavelengths)} wavelengths given for {hr_img.shape[2]} channels")

# Signed per-pixel error (HR - SR)
difference_image = hr_img - sr_img

# Channel slider. view_two_npy_images only takes `channel` and reads the images, img_id
# and model from the globals above, so paths/ids are not passed here (they are not
# parameters of the function and never reached it). The slider range follows the data.
interact(view_two_npy_images,
         channel=IntSlider(min=0, max=hr_img.shape[2] - 1, step=1, value=0))

# Show mean values per wavelength
plot_values_per_wavelengths(hr_wavelengths)

# Create Figures

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Select image, model, band and pixel to visualise
img_id = "wheatgrass_m4"
model = "HSI_x2_real"
channel = 12
selected_pixel = (375, 230)  # (row, col): row indexes axis 0, col indexes axis 1

# Paths: LR input, HR reference, SR model output, bicubic ("bi") baseline
lr_path = "Real-ESRGAN/inputs/" + img_id + ".npy"
hr_path = "results/c_hr/" + img_id + ".npy"
sr_path = "results/models/" + model + "/sr/" + img_id + ".npy"
bi_path = "results/c_bi_480/" + img_id + ".npy"

# Load the LR, HR, SR and bicubic images (lr_image is only used by the commented exports)
lr_image = np.load(lr_path)
hr_image = np.load(hr_path)
sr_image = np.load(sr_path)
bi_image = np.load(bi_path)

# Signed per-pixel error of the SR output (HR - SR)
difference_image = hr_image - sr_image

# Band wavelengths (nm), one per channel
wavelengths = [666, 679, 690, 704, 715,
               726, 739, 750, 764, 775,
               786, 800, 811, 822, 835,
               846, 860, 871, 882, 896,
               907, 918, 932, 943]
if len(wavelengths) != hr_image.shape[2]:
    raise ValueError(f"{len(wavelengths)} wavelengths given for {hr_image.shape[2]} channels")

# Spectra at the selected pixel (a single pixel, not a spatial mean)
values_hr = hr_image[selected_pixel[0], selected_pixel[1], :]
values_bi = bi_image[selected_pixel[0], selected_pixel[1], :]
values_sr = sr_image[selected_pixel[0], selected_pixel[1], :]


def _plot_pixel_spectra(curves, figsize=(5, 5), dpi=None):
    """Plot single-pixel spectra with a dashed red marker at the selected channel.

    Args:
        curves: sequence of (values, label) pairs, plotted in order against `wavelengths`.
        figsize: matplotlib figure size in inches.
        dpi: figure resolution; None keeps matplotlib's default.

    Returns:
        The created figure, so it can be saved before it is shown.
    """
    fig = plt.figure(figsize=figsize, dpi=dpi)
    for values, label in curves:
        plt.plot(wavelengths, values, label=label)
    plt.axvline(x=wavelengths[channel], color='red', linestyle='--', label=f'Channel {channel}')
    plt.xlabel('Wavelength (nm)')
    # These are raw values of one pixel, not means over the image
    plt.ylabel('Pixel Value')
    plt.ylim(0, 1)
    plt.legend()
    plt.tight_layout()
    return fig


# Spectra at the selected pixel: HR vs SR
_plot_pixel_spectra([(values_hr, 'HR Image'), (values_sr, 'SR Image')])
plt.show()

# Difference image for the selected channel on a symmetric [-1, 1] diverging scale
plt.imshow(difference_image[:, :, channel], cmap='bwr')
plt.colorbar()
plt.clim(-1, 1)
plt.title(f'Difference Image - Channel: {channel}')
plt.axis('off')
plt.show()

# Save images (uncomment to export)
# plt.imsave(f"visualizations/{model}/appendix/{img_id}_ch{channel}_lr.png", lr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
# plt.imsave(f"visualizations/{model}/appendix/{img_id}_ch{channel}_bi.png", bi_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
# plt.imsave(f"visualizations/{model}/appendix/{img_id}_ch{channel}_hr.png", hr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
# plt.imsave(f"visualizations/{model}/appendix/{img_id}_ch{channel}_sr.png", sr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
# plt.imsave(f"visualizations/{model}/appendix/{img_id}_ch{channel}_diff.png", difference_image[:, :, channel], cmap='bwr', vmin=-1, vmax=1)

# HR channel with the selected pixel marked by a red dot (scatter takes x=col, y=row)
plt.figure(figsize=(5, 5))
plt.imshow(hr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
plt.scatter([selected_pixel[1]], [selected_pixel[0]], c='red', s=75)
plt.axis('off')
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# savefig must run BEFORE plt.show(): with the inline backend show() closes the figure,
# so saving afterwards writes an empty image.
# plt.savefig(f"visualizations/{model}/appendix/{img_id}_ch{channel}_hr_with_pixel.png", bbox_inches='tight', pad_inches=0)
plt.show()

# Publication-resolution spectra: HR vs SR
_plot_pixel_spectra([(values_hr, 'HR Image'), (values_sr, 'SR Image')], figsize=(4, 4), dpi=300)
plt.show()

# Publication-resolution spectra: SR vs HR vs bicubic baseline (labelled "LR Image")
_plot_pixel_spectra(
    [(values_sr, 'SR Image'), (values_hr, 'HR Image'), (values_bi, 'LR Image')],
    figsize=(4, 4),
    dpi=300,
)
# Save image of wavelengths (before show, see note above)
# plt.savefig(f"visualizations/{model}/appendix/{img_id}_values.png", bbox_inches='tight', pad_inches=0)
plt.show()

# plt.imsave(f"visualizations/{img_id}_ch{channel}_hr.png", hr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)

### Save grayscale images

In [None]:
# Save grayscale images
import numpy as np
import matplotlib.pyplot as plt

img_id = "sunflower_l4"
channel = 12

# Test-split HR reference and its bi_240 (bicubic) LR counterpart
hr_path = "data/processed/full_hsi_val/test/hr/" + img_id + ".npy"
lr_path = "data/processed/full_hsi_val/test/bi_240/" + img_id + ".npy"

# Load the HR and LR images
hr_image = np.load(hr_path)
lr_image = np.load(lr_path)

# Preview the selected channel in grayscale (previously hard-coded to 12, ignoring `channel`)
plt.figure(figsize=(5, 5))
plt.imshow(lr_image[:, :, channel], cmap="gray")
plt.show()

# plt.imsave(f"visualizations/samples/{img_id}_ch{channel}_hr.png", hr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)
# plt.imsave(f"visualizations/samples/{img_id}_ch{channel}_lr.png", lr_image[:, :, channel], cmap='gray', vmin=0, vmax=1)

## Extract data from log files with regex

In [169]:
import re
import pandas as pd

# Parse a training log into one CSV row per training line (losses) or validation line (PSNR).
log_file_path = "Real-ESRGAN/experiments/finetune_HSIx2_val3_synth/train_finetune_HSIx2_val3_synth_20250105_214750.log"
output_csv_path = 'results/losses/finetune_HSIx2_val3_synth_losses5.csv'

# Compiled once rather than on every line
epoch_re = re.compile(r'\[epoch:\s*(\d+)')
pixel_loss_re = re.compile(r'l_g_pix:\s*([\d.eE+-]+)')
perceptual_loss_re = re.compile(r'l_g_percep:\s*([\d.eE+-]+)')
adversarial_loss_re = re.compile(r'l_g_gan:\s*([\d.eE+-]+)')
psnr_re = re.compile(r'psnr:\s*([\d.eE+-]+)')


def _search_float(pattern, text):
    """Return the first captured group of `pattern` in `text` as a float, or None if absent."""
    match = pattern.search(text)
    return float(match.group(1)) if match else None


epochs = []
pixel_losses = []
perceptual_losses = []
adversarial_losses = []
psnr_list = []

with open(log_file_path, 'r', encoding='utf-8') as log_file:
    for line in log_file:
        epoch_match = epoch_re.search(line)
        if epoch_match:
            # Training line: record the losses; a missing loss field becomes None
            # instead of crashing on `None.group(...)`.
            epochs.append(int(epoch_match.group(1)))
            pixel_losses.append(_search_float(pixel_loss_re, line))
            perceptual_losses.append(_search_float(perceptual_loss_re, line))
            adversarial_losses.append(_search_float(adversarial_loss_re, line))
            psnr_list.append(None)
            continue

        psnr = _search_float(psnr_re, line)
        if psnr is not None:
            # Validation line: attribute it to the most recent training epoch,
            # or None if no training line has been seen yet (avoids IndexError).
            epochs.append(epochs[-1] if epochs else None)
            pixel_losses.append(None)
            perceptual_losses.append(None)
            adversarial_losses.append(None)
            psnr_list.append(psnr)

# Save to csv file (None entries become empty cells)
data = {'epoch': epochs, 'pix_loss': pixel_losses, 'perc_loss': perceptual_losses, 'adv_loss': adversarial_losses, 'psnr': psnr_list}
df = pd.DataFrame(data)
df.to_csv(output_csv_path, index=False)