In [None]:
import glob
import os

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from google.colab import drive
from osgeo import gdal
# Mount Google Drive; later cells read GeoTIFFs from, and write results to, /content/drive/My Drive/.
drive.mount('/content/drive/')

# Side length in pixels of the square tiles; chl() below allocates its output with this size.
min_dim = 224

Mounted at /content/drive/


In [None]:
#HELPER FUNCTIONS
def chl(b1,b2,b3):
  """Return log10 of chlorophyll-a concentration (mg/m^3) for every pixel.

  Uses NASA's 4th-order polynomial band-ratio algorithm:
      r          = log10(max(b1, b2) / b3)
      log10(chl) = a0 + a1*r + a2*r**2 + a3*r**3 + a4*r**4
  The log value (not chl itself) is returned because that is the input boa needs.

  Args:
    b1, b2: the two blue bands; the larger of the two is used per pixel.
    b3: the green band.
    All three are array-likes of the same shape.

  Returns:
    float64 array with the same shape as the inputs. Pixels where the blue or
    green value is 0 (e.g. fill values), or where the ratio is negative, are NaN.
  """
  #chl coeff from NASA, units are mg/m^3
  a0, a1, a2, a3, a4 = 0.2412, -2.0546, 1.1776, -0.5538, -0.4570

  blue = np.maximum(np.asarray(b1, dtype=np.float64), np.asarray(b2, dtype=np.float64))
  green = np.asarray(b3, dtype=np.float64)
  # Zero reflectance yields inf/-inf ratios that the polynomial turns into NaN;
  # silence the warnings and let NaN mark those pixels as invalid.
  with np.errstate(divide='ignore', invalid='ignore'):
    r = np.log10(blue / green)
    return a0 + a1*r + a2*r**2 + a3*r**3 + a4*r**4

def sst(b10):
  """Convert Landsat ST_B10 digital numbers to surface temperature in degrees C.

  The scale factor 0.00341802 and offset 149 give kelvin; subtracting 273.15
  converts to Celsius. Works element-wise on scalars or NumPy arrays.
  """
  kelvin = b10 * 0.00341802 + 149
  celsius = kelvin - 273.15
  return celsius

#HELPER FUNCTIONS FOR BOA
def median_filter(data):
    """Return the median of all values in the ndarray `data`, regardless of its shape."""
    values = data.flatten()
    return np.median(values)

def peak_max(data, filter_size):
    """Return True if the centre of a square window is a directional maximum.

    The centre must be the (first-occurrence) argmax along all four lines
    through it: the middle row, the middle column, the main diagonal and the
    anti-diagonal. Assumes `data` is a filter_size x filter_size window, as
    produced by peak_masks.
    """
    centre = filter_size // 2
    lines = (
        data[centre],                # west-east row
        data[:, centre],             # north-south column
        np.diagonal(data),           # NW-SE diagonal
        np.fliplr(data).diagonal(),  # NE-SW anti-diagonal
    )
    return all(np.argmax(line) == centre for line in lines)

def peak_min(data, filter_size):
    """Return True if the centre of a square window is a directional minimum.

    Mirror of peak_max: the centre must be the (first-occurrence) argmin along
    the middle row, the middle column, the main diagonal and the anti-diagonal.
    Assumes `data` is a filter_size x filter_size window.
    """
    mid = filter_size // 2
    transects = [data[mid, :], data.T[mid], data.diagonal(), data[:, ::-1].diagonal()]
    for transect in transects:
        if np.argmin(transect) != mid:
            return False
    return True

def peak_masks(data, filter_size):
    """Mark pixels that are directional extrema of their centred window.

    Returns a float array shaped like `data`. It holds 1 where the
    filter_size x filter_size window centred on the pixel makes it a
    peak_max or peak_min, and 0 elsewhere. Pixels whose window would cross
    the image border stay 0.
    """
    half = filter_size // 2
    window_shape = (filter_size, filter_size)
    mask = np.zeros(data.shape)
    for row in range(data.shape[0]):
        for col in range(data.shape[1]):
            window = data[row - half:row + half + 1, col - half:col + half + 1]
            # Windows clipped at the border (or emptied by a negative start index)
            # have the wrong shape and are skipped.
            if window.shape != window_shape:
                continue
            is_extremum = peak_max(window, filter_size) or peak_min(window, filter_size)
            if is_extremum:
                mask[row, col] = 1
    return mask

def boa(img):
    """Apply one pass of the BOA contextual median filter to a 2-D image.

    A pixel that is a directional extremum of its 3x3 window (see peak_masks)
    but NOT of its 5x5 window is replaced by the median of its 3x3
    neighbourhood. Every other pixel is copied unchanged. Medians are taken
    from the input image, so the result does not depend on visiting order.

    Returns a new float64 array; `img` is not modified.
    """
    is_peak_5 = peak_masks(img, 5)
    is_peak_3 = peak_masks(img, 3)
    filtered = np.array(img, dtype=float)
    isolated = (is_peak_3 != 0) & (is_peak_5 == 0)
    for i, j in np.argwhere(isolated):
        filtered[i, j] = np.median(img[i - 1:i + 2, j - 1:j + 2])
    return filtered

def create_ground_truth(img, max_iter=None):
  """Smooth `img` with repeated boa passes until it stops changing, then return a Sobel edge map.

  Args:
    img: 2-D array (in this notebook, log10 chlorophyll from chl() or SST from sst()).
    max_iter: optional cap on the number of boa passes. None (the default)
      iterates until two consecutive passes are identical.

  Returns:
    float32 array of img.shape equal to 0.5 * Sobel_x + 0.5 * Sobel_y (3x3
    kernels) of the smoothed image.
  """
  old_step = img
  new_step = boa(img)
  passes = 1
  # equal_nan=True: NaN pixels (chl() produces them on zero/fill reflectance)
  # never compare equal to themselves, which would make this loop run forever.
  while not np.array_equal(old_step, new_step, equal_nan=True):
    if max_iter is not None and passes >= max_iter:
      break
    old_step = new_step
    new_step = boa(old_step)
    passes += 1

  smoothed = np.float32(new_step)
  gx = cv.Sobel(smoothed, cv.CV_32F, 1, 0, ksize=3)
  gy = cv.Sobel(smoothed, cv.CV_32F, 0, 1, ksize=3)
  # NOTE(review): this is a signed average of the two gradients, not a gradient
  # magnitude, so opposite-signed gx and gy cancel (e.g. fronts along a diagonal).
  # Left as-is because saved ground truth and the signed 'seismic' plots depend on it.
  sobel_final = cv.addWeighted(gx, 0.5, gy, 0.5, 0)
  return sobel_final

def random_crop(dim, size=224):
  """Return a random start offset for a `size`-pixel crop along an axis of length `dim`.

  Args:
    dim: length of the raster axis in pixels.
    size: crop length in pixels (default 224, the tile size used in this notebook).

  Returns:
    int drawn uniformly from [0, dim - size] inclusive, using NumPy's global RNG.

  Raises:
    ValueError: if dim < size, i.e. no crop of that size fits.
  """
  if dim < size:
    raise ValueError(f'axis length {dim} is smaller than crop size {size}')
  # randint's upper bound is exclusive; +1 lets the crop reach the far edge and
  # makes dim == size (offset 0) valid instead of raising.
  return np.random.randint(0, dim - size + 1)

In [None]:
# Build (input, ground-truth) training pairs from every Indian Ocean GeoTIFF.
# To process another ocean, change the three directory constants below.
SRC_DIR = '/content/drive/My Drive/indian'
TRAIN_DIR = '/content/drive/My Drive/indian_train'
VIZ_DIR = '/content/drive/My Drive/indian_viz'
CROP_SEED = 0  # fixes the random crop offsets so reruns reproduce the same tiles

os.makedirs(TRAIN_DIR, exist_ok=True)
os.makedirs(VIZ_DIR, exist_ok=True)
np.random.seed(CROP_SEED)


def _read_crop(dataset, band_index, x_off, y_off):
  """Read only the min_dim x min_dim window of one band starting at pixel (x_off, y_off)."""
  band = dataset.GetRasterBand(band_index)
  return band.ReadAsArray(xoff=x_off, yoff=y_off, win_xsize=min_dim, win_ysize=min_dim)


# Sorted so that, with the fixed seed, each file always gets the same crop.
for tif_path in sorted(glob.glob(os.path.join(SRC_DIR, '*.tif'))):
  dataset = gdal.Open(tif_path)
  if dataset is None:
    raise RuntimeError(f'GDAL could not open {tif_path}')
  stem = os.path.splitext(os.path.basename(tif_path))[0]

  # One random window per scene, shared by every band so the layers stay aligned.
  x_crop = random_crop(dataset.RasterXSize)
  y_crop = random_crop(dataset.RasterYSize)

  # Ground truth: bands 1-3 feed chlorophyll, band 5 holds ST_B10 for SST.
  sr_b1, sr_b2, sr_b3 = (_read_crop(dataset, b, x_crop, y_crop) for b in (1, 2, 3))
  st_b10 = _read_crop(dataset, 5, x_crop, y_crop)

  boa_chl = create_ground_truth(chl(sr_b1, sr_b2, sr_b3))
  boa_sst = create_ground_truth(sst(st_b10))

  output = np.stack([boa_chl, boa_sst], axis=2)
  np.save(os.path.join(TRAIN_DIR, stem + '_output.npy'), output)
  plt.imsave(os.path.join(VIZ_DIR, stem + '_output_chl.png'), boa_chl, vmin=-0.1, vmax=0.1, cmap='seismic')
  plt.imsave(os.path.join(VIZ_DIR, stem + '_output_sst.png'), boa_sst, vmin=-1, vmax=1, cmap='seismic')

  # Model inputs: bands 6-10 (B1, B2, B3, B4, B10), each rescaled by 255 / 32727.5.
  # NOTE(review): 32727.5 may be a typo for 32767.5 (= 65535 / 2) — confirm before regenerating data.
  model_input = np.stack(
      [_read_crop(dataset, b, x_crop, y_crop) / 32727.5 * 255 for b in (6, 7, 8, 9, 10)],
      axis=2)
  np.save(os.path.join(TRAIN_DIR, stem + '_input.npy'), model_input)
  plt.imsave(os.path.join(VIZ_DIR, stem + '_input_B1.png'), model_input[:, :, 0], vmin=60, vmax=110, cmap='ocean')
  plt.imsave(os.path.join(VIZ_DIR, stem + '_input_B10.png'), model_input[:, :, 4], vmin=130, vmax=230, cmap='ocean')

  dataset = None  # release the GDAL file handle before opening the next scene