Notebook authors: Özgün Haznedar, Elena Gronskaya

The purpose of this notebook is to patch functions of the original ISR (Image Super-Resolution) repository so that they work with PNG satellite images sourced from Google Earth Engine and pre-processed with GDAL_transformer_PNG.ipynb.

Usage: if ISR is installed via pip, copy the cells of this notebook into the training/prediction notebooks and run them at the beginning, before any ISR DataHandler, Predictor or Trainer objects are created. Alternatively, apply these changes directly to a local copy of the ISR repository.

In [None]:
!pip install gast"==0.3.2"
!pip install ISR
!pip install 'h5py==2.10.0' --force-reinstall
import ISR

In [None]:
import ISR.utils.datahandler

def datahandler_init(self, lr_dir, hr_dir, patch_size, scale, n_validation_samples=None):
    """Patched ``ISR.utils.datahandler.DataHandler.__init__``.

    Same initialization as the ISR data handler, with ``.tif`` included in the
    admissible image extensions.

    Args:
        lr_dir: folder holding the low-resolution images.
        hr_dir: folder holding the high-resolution images.
        patch_size: LR patch size; the HR patch size is ``patch_size * scale``.
        scale: upscaling factor between LR and HR images.
        n_validation_samples: stored on the handler as-is (default ``None``).
    """
    # Imported here because this function runs with the notebook's globals,
    # where the ISR module's own ``get_logger`` import is not visible.
    from ISR.utils.logger import get_logger

    self.folders = {'hr': hr_dir, 'lr': lr_dir}  # image folders
    self.extensions = ('.png', '.jpeg', '.jpg', '.tif')  # admissible extensions
    self.img_list = {}  # list of file names
    self.n_validation_samples = n_validation_samples
    self.scale = scale
    self.patch_size = {'lr': patch_size, 'hr': patch_size * self.scale}
    # Use the upstream module's logger name: in a notebook ``__name__`` is
    # '__main__', which would be shared with every other patched class.
    self.logger = get_logger('ISR.utils.datahandler')
    # Upstream helpers (defined on DataHandler, not here).
    self._make_img_list()
    self._check_dataset()


ISR.utils.datahandler.DataHandler.__init__ = datahandler_init

In [None]:
import ISR.predict.predictor

# fix for admissible extension
def predictor_init(self, input_dir, output_dir='./data/output', verbose=True):
    """Patched ``ISR.predict.predictor.Predictor.__init__``.

    Collects the files in ``input_dir`` whose extension is ``.jpeg``, ``.jpg``,
    ``.png`` or ``.tif`` (compared case-insensitively). It also makes sure the
    results folder ``output_dir / <name of input_dir>`` exists.

    Args:
        input_dir: folder with the images to process.
        output_dir: root folder for results; a sub-folder named after
            ``input_dir`` is created inside it if missing.
        verbose: when False, the logger level is raised to ERROR.

    Raises:
        ValueError: if ``input_dir`` contains no admissible image file.
    """
    # Imported here because this function runs with the notebook's globals,
    # where the ISR module's own imports are not visible.
    import logging
    from pathlib import Path

    from ISR.utils.logger import get_logger

    self.input_dir = Path(input_dir)
    self.data_name = self.input_dir.name
    self.output_dir = Path(output_dir) / self.data_name
    # Use the upstream module's logger name: in a notebook ``__name__`` is
    # '__main__', so silencing it would also silence other patched classes.
    self.logger = get_logger('ISR.predict.predictor')
    if not verbose:
        self.logger.setLevel(logging.ERROR)
    self.extensions = ('.jpeg', '.jpg', '.png', '.tif')  # file extensions that are admitted
    # Lower-case the suffix so files such as 'TILE.PNG' are accepted too.
    self.img_ls = [f for f in self.input_dir.iterdir() if f.suffix.lower() in self.extensions]
    if not self.img_ls:
        self.logger.error('No valid image files found (check config file).')
        raise ValueError('No valid image files found (check config file).')
    # Create results folder
    if not self.output_dir.exists():
        self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
        self.output_dir.mkdir(parents=True)


ISR.predict.predictor.Predictor.__init__ = predictor_init

# fix for posix path error

def _load_weights(self):
    """ Invokes the model's load weights function if any weights are provided. """
    if self.weights_path is not None:
        self.logger.info('Loaded weights from \n > {}'.format(self.weights_path))
        # loading by name automatically excludes the vgg layers
        self.model.model.load_weights(str(self.weights_path))
    else:
        self.logger.error('Error: Weights path not specified (check config file).')
        raise ValueError('Weights path not specified (check config file).')

    session_config_path = self.weights_path.parent / 'session_config.yml'
    if session_config_path.exists():
        conf = yaml.load(session_config_path.read_text(), Loader=yaml.FullLoader)
    else:
        self.logger.warning('Could not find weights training configuration')
        conf = {}
    conf.update({'pre-trained-weights': self.weights_path.name})
    return conf

# Install the Predictor loader now: the global name `_load_weights` is
# rebound by the Trainer patch cell below, so this line must run before it.
ISR.predict.predictor.Predictor._load_weights=_load_weights      

In [None]:
import ISR.train.trainer

# fix for posix path error
def _load_weights(self):
    """
    Loads the pretrained weights from the given path, if any is provided.
    If a discriminator is defined, does the same.
    """
    
    if self.weights_generator:
        self.model.get_layer('generator').load_weights(str(self.weights_generator))
    
    if self.discriminator:
        if self.weights_discriminator:
            self.model.get_layer('discriminator').load_weights(str(self.weights_discriminator))
            self.discriminator.model.load_weights(str(self.weights_discriminator))

# Install the Trainer loader: binds whatever `_load_weights` currently refers
# to, so this must run after the Trainer definition directly above.
ISR.train.trainer.Trainer._load_weights=_load_weights  