<a href="https://colab.research.google.com/github/astroChance/RadNET/blob/master/Post_DL_workbook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Drive Mount

In [None]:
!pip install --upgrade segyio

In [None]:
import numpy as np
from math import sqrt
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import StrMethodFormatter
import segyio
import time
import itertools
import os
import re
import random
from PIL import Image
from scipy import spatial, signal
import json
from sklearn.preprocessing import normalize, StandardScaler

In [None]:
# Mount Google Drive (Colab only) so the "/content/drive/My Drive/..." paths used by later cells resolve.
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
### DL prediction images, plus the corresponding 2D (target) and 3D (input) images.
# Each list holds the full path of every file found anywhere under its root
# directory. The lists are sorted so their order is deterministic across runs:
# os.walk yields entries in arbitrary filesystem order.
# NOTE(review): if these lists are paired by index later, confirm the filenames
# in the three directories sort into matching order.

def list_files_recursive(root_dir):
    """Return the sorted full paths of all files under root_dir, recursively.

    A missing directory yields an empty list (os.walk's default behaviour).
    """
    return sorted(
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in os.walk(root_dir)
        for fname in fnames
    )


predict_path = "/content/drive/My Drive/RadNET/GAN data/Dev_Testing_Data/Test_Predictions/"
InferenceFiles = list_files_recursive(predict_path)

_2D_image_path = "/content/drive/My Drive/RadNET/GAN data/Dev_Testing_Data/2D_testing"
ImageFiles_2D = list_files_recursive(_2D_image_path)

_3D_image_path = "/content/drive/My Drive/RadNET/GAN data/Dev_Testing_Data/3D_testing"
ImageFiles_3D = list_files_recursive(_3D_image_path)

In [None]:
### SEGY Data filenames
# 3D

Volume3D = "/content/drive/My Drive/RadNET/GAN data/3D/PB3D_Fritz_subset.sgy"

# 2D
# Full path of every file under the 2D directory (recursively). The list is
# sorted so its order is deterministic; os.walk itself yields entries in
# arbitrary filesystem order.

_2D_segy_path = "/content/drive/My Drive/RadNET/GAN data/2D"
TwoDFiles = sorted(
    os.path.join(dirpath, fname)
    for dirpath, _subdirs, fnames in os.walk(_2D_segy_path)
    for fname in fnames
)

# Custom Functions

In [None]:
########
# Griffin Lim implementation

def _overlap_rmse(sig_a, sig_b):
    """Return the RMS difference of two 1-D signals over their overlapping prefix."""
    n = min(len(sig_a), len(sig_b))
    diff = np.asarray(sig_a[:n]) - np.asarray(sig_b[:n])
    return sqrt(np.mean(diff ** 2))


def griffin_lim(magnitude, iterations, orig_sig, fs, nperseg, noverlap, window):
    """
    Reconstruct a time-domain signal from a magnitude spectrogram with the
    Griffin-Lim algorithm (GLA).

    Each iteration takes the STFT of the current estimate and keeps its phase.
    It then replaces the STFT magnitude with ``magnitude`` and inverts the
    result with the ISTFT.

    Inputs:
      magnitude: the real-valued spectrogram array to be converted to signal.
            Its shape must match the STFT of orig_sig under the given
            parameters, otherwise the magnitude * phase product fails to broadcast.
      iterations: number of iterations to perform; suggest several hundred -> 1000
      orig_sig: the original time-domain signal prior to enhancement
            - typical GLA uses a random initial signal; starting from the
              original signal improves stability
      fs, nperseg, noverlap, window: parameters from the original STFT

    Returns:
      (sig_recon, error)
        sig_recon: GLA reconstructed time-domain signal. scipy's STFT pads its
            input to a whole number of segments, so this can be slightly longer
            than orig_sig. If iterations <= 0, orig_sig itself is returned.
        error: list of RMS differences between successive estimates. One entry
            is recorded whenever the remaining-iteration count is a multiple
            of 10. When lengths differ, only the overlapping prefix is compared.
    """
    sig_recon = orig_sig
    error = []

    # Count down (rather than up) so convergence is sampled whenever the
    # *remaining* iteration count is a multiple of 10.
    remaining = iterations
    while remaining > 0:
        _, _, temp_spec = signal.stft(sig_recon, window=window, fs=fs, nperseg=nperseg, noverlap=noverlap)
        # Keep the current phase estimate, impose the target magnitude.
        new_spec = magnitude * np.exp(1j * np.angle(temp_spec))
        prev_sig = sig_recon
        _, sig_recon = signal.istft(new_spec, window=window, fs=fs, nperseg=nperseg, noverlap=noverlap)

        if remaining % 10 == 0:
            error.append(_overlap_rmse(sig_recon, prev_sig))

        remaining -= 1

    return sig_recon, error





########
# Return the 0-1 data value corresponding to RGB color

def get_val_color(RGB_val, cmap):
    """
    Invert a colormap lookup: return the 0-1 value whose colormap entry is
    closest to an RGB colour.

    Inputs:
      RGB_val: sequence whose first three entries are R, G, B. These are
            compared directly against cmap's entries, so they must be on the
            same scale; matplotlib colormaps return RGBA floats in 0-1.
            NOTE(review): 8-bit pixel values (0-255) would need dividing by
            255 first.
      cmap: colormap object exposing ``N`` and callable on an integer index
            array, returning one RGB(A) row per index

    Returns:
      color_idx / cmap.N for the entry with the smallest L1 (sum of absolute)
      RGB distance; ties resolve to the lowest index. Index i maps to i / N,
      so the result lies in [0, (N - 1) / N].
    """
    colors = cmap(np.arange(0, cmap.N))

    # L1 distance to every colormap entry at once, summed in R + G + B order.
    dist = (np.abs(RGB_val[0] - colors[:, 0])
            + np.abs(RGB_val[1] - colors[:, 1])
            + np.abs(RGB_val[2] - colors[:, 2]))

    # argmin returns the first minimum, so ties go to the lowest index.
    color_idx = dist.argmin()
    value = color_idx / cmap.N
    return value

# Compare a sample of predictions to the 2D and 3D images

# Convert prediction images back to magnitude spectra arrays

In [None]:
## Parameters from image generation
# Colormap name plus the low/high data values; presumably the colour-scale
# limits the images were rendered with (confirm against the generation notebook).
colormap = 'jet'
min_val, max_val = -10, 0

In [None]:
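## Remember: images were displayed as log(abs(STFT output)), i.e. the natural log of the magnitude.

## Once a 0-1 value is grabbed from a pixel colour (get_val_color), first rescale it to the
## colour-scale range (presumably min_val + value * (max_val - min_val)), then apply
## np.exp() to recover the STFT magnitude.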

# Run GLA on predicted magnitude spectra

In [None]:
## Parameters from initial STFT

fs = 1 / 0.0000000375   # sampling frequency: reciprocal of a 3.75e-8 sample interval (presumably seconds, i.e. 37.5 ns)
nperseg = 128
noverlap = 96           # hop between segments = nperseg - noverlap = 32 samples
# The scipy.signal.hann alias is deprecated and has been removed from recent SciPy
# releases; the window lives in scipy.signal.windows. sym=False gives the periodic
# window intended for spectral analysis.
window = signal.windows.hann(nperseg, sym=False)