Notebook author: Elena Gronskaya

The purpose of this notebook is to adjust functions from the original ISR repo so that they work with raw TIF files sourced from Google Earth Engine. If you are using PNG satellite images (pre-processed with GDAL_transformer_PNG.ipynb), use ISR_module_adjustments_PNG.ipynb instead.

Usage: copy the cells in this notebook into the training/prediction notebooks and run them at the beginning (if ISR is installed through pip install). Alternatively, apply these changes directly to a local copy of the ISR repo.

In [None]:
!pip install gast"==0.3.2"
!pip install ISR
!pip install 'h5py==2.10.0' --force-reinstall

In [None]:
import ISR
import numpy as np
import ISR.utils.datahandler
import ISR.predict.predictor

In [None]:
def _make_img_list(self):
    """Populate ``self.img_list`` with the sorted file names of each folder.

    For both resolutions ('hr' and 'lr') every entry of
    ``self.folders[res]`` is listed and stored, sorted, as a numpy array in
    ``self.img_list[res]``. No extension filtering is applied here.

    If ``self.n_validation_samples`` is truthy, that many indices are drawn
    at random without replacement and the same indices are used to subset
    both lists.
    """
    # Imported locally because no cell of this notebook imports ``os`` at
    # module level; without this the function fails with a NameError.
    import os

    for res in ['hr', 'lr']:
        self.img_list[res] = np.sort(os.listdir(self.folders[res]))

    if self.n_validation_samples:
        # One set of indices for both resolutions, so the same positions
        # are selected from the 'hr' and the 'lr' list.
        samples = np.random.choice(
            len(self.img_list['hr']), self.n_validation_samples, replace=False
        )
        for res in ['hr', 'lr']:
            self.img_list[res] = self.img_list[res][samples]

# Monkey-patch: make ISR's DataHandler use the _make_img_list defined above.
ISR.utils.datahandler.DataHandler._make_img_list = _make_img_list

In [None]:
def get_batch(self, batch_size, idx=None, flatness=0.0):
    """
    Returns a dictionary with keys ('lr', 'hr') containing training batches
    of Low Res and High Res image patches.

    One LR/HR image pair is read from disk. The LR image goes through the
    landsat rescaling and the HR image through the sentinel rescaling; both
    are truncated to integers, clipped to [0, 255] and divided by 255, so
    the values handed to the cropping step lie in [0, 1].

    Args:
        batch_size: integer.
        idx: position of the image pair in ``self.img_list``. When None, a
            random pair is selected (idx is given at validation time).
        flatness: float in [0,1], is the patch "flatness" threshold.
            Determines what level of detail the patches need to meet.
            0 means any patch is accepted.

    Returns:
        The dictionary produced by ``self._crop_imgs`` after both entries
        have been passed through ``self._transform_batch`` with the same
        randomly drawn transforms.
    """
    # Imported locally because no cell of this notebook imports ``os`` at
    # module level.
    import os

    # Compare against None explicitly: ``not idx`` would also be true for
    # idx == 0 and silently replace image 0 with a random one.
    if idx is None:
        idx = np.random.choice(range(len(self.img_list['hr'])))

    img = {}
    for res in ['lr', 'hr']:
        img_path = os.path.join(self.folders[res], self.img_list[res][idx])
        raw = imageio.imread(img_path).astype(int)

        # different normalization for landsat and sentinel images
        # NOTE(review): 0.0000275 and -0.2 look like a Landsat reflectance
        # scale/offset, and the remaining factors (x4, 255/3558, x1.4) like
        # empirical brightness gains - confirm with the notebook author.
        if res == 'lr':
            scaled = ((raw * 0.0000275 - 0.2) * 255 * 4).astype(int)  # landsat
        else:
            scaled = ((raw * 255 / 3558) * 1.4).astype(int)  # sentinel

        # Saturate to the 8-bit range, then map to [0, 1].
        img[res] = np.clip(scaled, 0, 255) / 255.0

    batch = self._crop_imgs(img, batch_size, flatness)
    # The same transform codes are applied to LR and HR so patches stay paired.
    transforms = np.random.randint(0, 3, (batch_size, 2))
    batch['lr'] = self._transform_batch(batch['lr'], transforms)
    batch['hr'] = self._transform_batch(batch['hr'], transforms)

    return batch

# Monkey-patch: make ISR's DataHandler use the get_batch defined above.
ISR.utils.datahandler.DataHandler.get_batch = get_batch

In [None]:
import imageio
from pathlib import Path
from ISR.utils.logger import get_logger
import yaml

def _forward_pass(self, file_path):
    """Read one image, rescale it to the 0-255 range and run the model on it.

    Args:
        file_path: path of the image file, read with ``imageio.imread``.

    Returns:
        The result of ``self.model.predict`` when the image has exactly
        three channels; otherwise None, after logging an error.
    """
    raw = imageio.imread(file_path).astype(int)
    # Same landsat rescaling as used for the LR images at training time.
    lr_img = ((raw * 0.0000275 - 0.2) * 255 * 4).astype(int)
    lr_img = np.clip(lr_img, 0, 255)

    # Check ndim first: a single-band image is 2-D and has no shape[2], which
    # would raise IndexError instead of reaching the error message below.
    if lr_img.ndim == 3 and lr_img.shape[2] == 3:
        return self.model.predict(lr_img)

    self.logger.error('{} is not an image with 3 channels.'.format(file_path))
    return None

# Monkey-patch: make ISR's Predictor use the _forward_pass defined above.
ISR.predict.predictor.Predictor._forward_pass = _forward_pass

def predictor_init(self, input_dir, output_dir='./data/output', verbose=True):
    """Set up the input/output folders and collect the images to process.

    Args:
        input_dir: folder containing the images to run predictions on.
        output_dir: base folder for results; the name of ``input_dir`` is
            appended to it to form ``self.output_dir``.
        verbose: when False, the logger level is raised to 40.

    Raises:
        ValueError: if ``input_dir`` contains no file with an admitted
            extension.
    """
    self.input_dir = Path(input_dir)
    self.data_name = self.input_dir.name
    self.output_dir = Path(output_dir) / self.data_name
    self.logger = get_logger(__name__)
    if not verbose:
        self.logger.setLevel(40)  # 40 is the numeric value of logging.ERROR

    # file extensions that are admitted (compared case-insensitively, so
    # files such as 'scene.TIF' are not silently skipped)
    self.extensions = ('.jpeg', '.jpg', '.png', '.tif', '.tiff')
    self.img_ls = [
        f for f in self.input_dir.iterdir() if f.suffix.lower() in self.extensions
    ]
    if not self.img_ls:
        message = 'No valid image files found (check config file).'
        self.logger.error(message)
        raise ValueError(message)

    # Create results folder
    if not self.output_dir.exists():
        self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
        self.output_dir.mkdir(parents=True)

# Monkey-patch: replace the constructor of ISR's Predictor with predictor_init.
ISR.predict.predictor.Predictor.__init__ = predictor_init

# Fix for the PosixPath error: the weights path is converted to str before it is passed to load_weights.

def _load_weights(self):
    """Load the model weights and return the training configuration.

    The weights at ``self.weights_path`` are loaded into
    ``self.model.model``; the path is passed as ``str``. If a
    ``session_config.yml`` file sits next to the weights file it is parsed
    and used as the configuration, otherwise a warning is logged and an
    empty configuration is used.

    Returns:
        dict: the configuration, with the key 'pre-trained-weights' set to
        the file name of the weights.

    Raises:
        ValueError: if ``self.weights_path`` is None.
    """
    if self.weights_path is None:
        self.logger.error('Error: Weights path not specified (check config file).')
        raise ValueError('Weights path not specified (check config file).')

    # Accept plain strings as well as Path objects.
    weights_path = Path(self.weights_path)

    # Original note: loading by name automatically excludes the vgg layers
    # NOTE(review): no by_name argument is passed here - confirm the note.
    self.model.model.load_weights(str(weights_path))
    # Logged after the call so the message is only emitted on success.
    self.logger.info('Loaded weights from \n > {}'.format(weights_path))

    session_config_path = weights_path.parent / 'session_config.yml'
    if session_config_path.exists():
        # NOTE(review): FullLoader is not meant for untrusted input; only
        # point this at session configs you produced yourself.
        # An empty YAML file parses to None, hence the fallback to {}.
        conf = yaml.load(session_config_path.read_text(), Loader=yaml.FullLoader) or {}
    else:
        self.logger.warning('Could not find weights training configuration')
        conf = {}
    conf.update({'pre-trained-weights': weights_path.name})
    return conf

# Monkey-patch: make ISR's Predictor use the _load_weights defined above.
ISR.predict.predictor.Predictor._load_weights=_load_weights

In [None]:
# Fix for the PosixPath error: weight paths are converted to str before they are passed to load_weights.
import ISR.train.trainer

def _load_weights(self):
    """
    Restore pretrained weights, when paths for them are set.

    The generator weights go into the 'generator' layer of ``self.model``.
    When a discriminator is in use and has a weights path, those weights go
    into both the 'discriminator' layer of ``self.model`` and
    ``self.discriminator.model``. Paths are converted to ``str`` first.
    """
    generator_weights = self.weights_generator
    if generator_weights:
        generator_layer = self.model.get_layer('generator')
        generator_layer.load_weights(str(generator_weights))

    # Nothing more to do without a discriminator and weights for it.
    if not (self.discriminator and self.weights_discriminator):
        return

    discriminator_weights = str(self.weights_discriminator)
    self.model.get_layer('discriminator').load_weights(discriminator_weights)
    self.discriminator.model.load_weights(discriminator_weights)

# Monkey-patch: make ISR's Trainer use the _load_weights defined above.
ISR.train.trainer.Trainer._load_weights=_load_weights