In [None]:
from __future__ import print_function
import numpy as np
import math
import scipy
import pandas as pd
import PIL
import gdal
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import sys, os
from pathlib import Path
import time
import xml.etree.ElementTree as ET
import random
import collections, functools, operator
import csv

import ee

from osgeo import gdal,osr
from gdalconst import *
import subprocess
from osgeo.gdalconst import GA_Update

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, MSELoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout, Sigmoid
from torch.optim import Adam, SGD
from torchvision import transforms, utils

import skimage
from skimage import io, transform
import sklearn
import sklearn.metrics
from sklearn.feature_extraction import image
from sklearn import svm

# Setting up model options and reading the dataset

## Model options

In [None]:
# Name of the experiment to run; it is looked up as a key of the `experiments` dict in a later cell.
exp_to_run = 'CiudadReal_SAMGAN'

In [None]:
def train(train_loader, dirModel=None):
    """Train the model family selected by the global ``run_experiment``.

    The function works entirely on notebook-level globals: ``net``/``optimizer``
    for CNN experiments, ``gen``/``gen_opt``/``disc``/``disc_opt`` for the
    P2P, iPAN and SAMGAN experiments, and ``device`` for checkpoint placement.

    Parameters
    ----------
    train_loader
        Training batches; forwarded unchanged to ``CNNtrain`` or ``P2Ptrain``.
    dirModel : str, optional
        Path of a checkpoint to restore before training. When falsy, nothing
        is loaded and the models/optimizers keep their current state.

    Notes
    -----
    If ``run_experiment`` contains none of the known tags the call is a no-op.
    """
    if 'CNN' in run_experiment:
        if dirModel:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only runtime (and vice versa); without it torch.load tries to
            # rebuild the tensors on the device they were saved from.
            loaded_state = torch.load(dirModel, map_location=device)
            net.load_state_dict(loaded_state["net"])
            optimizer.load_state_dict(loaded_state["net_opt"])
        CNNtrain(train_loader)
    elif any(tag in run_experiment for tag in ('P2P', 'iPAN', 'SAMGAN')):
        if dirModel:
            loaded_state = torch.load(dirModel, map_location=device)
            gen.load_state_dict(loaded_state["gen"])
            gen_opt.load_state_dict(loaded_state["gen_opt"])
            disc.load_state_dict(loaded_state["disc"])
            disc_opt.load_state_dict(loaded_state["disc_opt"])
        P2Ptrain(train_loader)

def test(test_loader, dirModel=None, train_loader=None, saveMetrics=None, svc=None):
    #svc = None
    #if train_loader:
    #    svc = svmClassifier()
    #    svc.train(train_loader)
    
    if 'CNN' in run_experiment:
        if dirModel:
            loaded_state = torch.load(dirModel)
            net.load_state_dict(loaded_state["net"])
            optimizer.load_state_dict(loaded_state["net_opt"])
        CNNtest(test_loader, vizImages=True, svc=svc, saveMetrics=saveMetrics)
    elif 'P2P' in run_experiment or 'iPAN' in run_experiment or 'SAMGAN' in run_experiment:
        if dirModel:
            loaded_state = torch.load(dirModel)
            gen.load_state_dict(loaded_state["gen"])
            gen_opt.load_state_dict(loaded_state["gen_opt"])
            disc.load_state_dict(loaded_state["disc"])
            disc_opt.load_state_dict(loaded_state["disc_opt"])
        P2Ptest(test_loader, vizImages=True, svc=svc, saveMetrics=saveMetrics)

In [None]:
# Every experiment is named '<site>_<model>'. The dict maps each name to itself,
# so looking up an unknown `exp_to_run` raises a KeyError.
_sites = ('CiudadReal', 'California')
_models = ('CNN', 'P2P', 'iPAN', 'SAMGAN')
experiments = {f'{_s}_{_m}': f'{_s}_{_m}' for _s in _sites for _m in _models}

run_experiment = experiments[exp_to_run]


def _scene_dirs(root):
    """Return ``root/<entry>/<entry without its last 5 characters>`` for every
    entry found in ``root``, ordered from last to first in sorted order."""
    return [os.path.join(root, entry, entry[:-5]) for entry in reversed(sorted(os.listdir(root)))]


# --- Scene directories ------------------------------------------------------
# 'input' and 'target' must be lists even if there is only one scene directory.
_ciudad_real_dirs = {
    'input': _scene_dirs(os.getcwd() + '/drive/My Drive/TFG/Ciudad Real/EO-1 ALI Processed Aligned/'),
    'target': _scene_dirs(os.getcwd() + '/drive/My Drive/TFG/Ciudad Real/EO-1 Hyperion/'),
}
_california_dirs = {
    'input': [os.getcwd() + '/drive/My Drive/TFG/California/EO-1 ALI/EO1A0420352016053110K2_1GST/EO1A0420352016053110K2/'],
    'target': [os.getcwd() + '/drive/My Drive/TFG/California/EO-1 Hyperion/EO1H0420352016053110K2/'],
}

# Each experiment gets its own dict, but all experiments of one site share the
# same underlying directory lists.
experimentDirs = {}
for _model in _models:
    experimentDirs[f'CiudadReal_{_model}'] = dict(_ciudad_real_dirs)
for _model in _models:
    experimentDirs[f'California_{_model}'] = dict(_california_dirs)
# Only the California CNN experiment is given the cropland layer.
experimentDirs['California_CNN']['crop'] = os.getcwd() + '/drive/My Drive/TFG/California/USA NASS Cropland Data Layer/USA_NASS_CDL_CALIFORNIA2016.tif'

# --- Per-experiment option tables -------------------------------------------
_normalization_by_experiment = {
    name: NormalizeOptions(norm_type='fullRange_norm') for name in experiments
}

# Max window size is 256 because of hyperion width.
# miniPatches is enabled for the CNN experiments only.
_patches_by_experiment = {
    name: ToPatchesOptions(window=64, step=64, evadeClouds=False, miniPatches=name.endswith('_CNN'))
    for name in experiments
}

_read_while_running_by_experiment = dict.fromkeys(experiments, False)

_learning_rate_by_experiment = {**dict.fromkeys(experiments, 0.005), 'California_CNN': 0.0005}

_num_epochs_by_experiment = dict.fromkeys(experiments, 400)

_batch_size_by_experiment = {
    **dict.fromkeys(experiments, 1),
    'California_P2P': 32,
    'California_iPAN': 32,
    'California_SAMGAN': 32,
}

_display_epoch_by_experiment = {
    name: 10 if name.endswith('_CNN') else 1 for name in experiments
}

# --- Settings of the selected experiment ------------------------------------
dirs = experimentDirs[run_experiment]
normalization_options = _normalization_by_experiment[run_experiment]
patches_options = _patches_by_experiment[run_experiment]
readWhileRunning = _read_while_running_by_experiment[run_experiment]
readFromPatches = True

learning_rate = _learning_rate_by_experiment[run_experiment]
num_epochs = _num_epochs_by_experiment[run_experiment]
batch_size = _batch_size_by_experiment[run_experiment]
display_epoch = _display_epoch_by_experiment[run_experiment]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

for dset, scene_dirs in dirs.items():
    print(dset, 'dataset scene directories:', scene_dirs)

In [None]:
_gan_tags = ('P2P', 'iPAN', 'SAMGAN')

if 'CNN' in run_experiment:
    # Model: built on `device`, then weights_init is applied to it.
    net = CNNnet().to(device).apply(weights_init)
    # Optimizer and loss for the CNN.
    optimizer = Adam(params=net.parameters(), lr=learning_rate)
    loss_fn = MSELoss().to(device)
elif any(tag in run_experiment for tag in _gan_tags):
    # Generator / discriminator pair. The discriminator is sized with the sum of
    # the two generator arguments.
    # NOTE(review): 9 and 170 look like input/output band counts — confirm.
    _gen_in, _gen_out = 9, 170
    gen = UNet(_gen_in, _gen_out).to(device).apply(weights_init)
    disc = Discriminator(_gen_in + _gen_out).to(device).apply(weights_init)
    # One Adam optimizer per network, sharing the same learning rate.
    gen_opt = Adam(params=gen.parameters(), lr=learning_rate)
    disc_opt = Adam(params=disc.parameters(), lr=learning_rate)
    # Loss functions and their weights.
    adv_criterion = nn.BCEWithLogitsLoss()
    recon_criterion = nn.L1Loss()
    target_shape = 64
    lambda_recon = 200
    # lambda_sam is only defined for the SAMGAN experiments.
    if 'SAMGAN' in run_experiment:
        lambda_sam = 200

## Call Satellite dataset to read the data for the appropriate model

In [None]:
if readFromPatches:
    # Pre-cut patches are read from fixed "<split> Input/" and "<split> Target/"
    # folders under this root.
    patches_root = f'{os.getcwd()}/drive/My Drive/TFG/Ciudad Real/Patches/Dataset_128BS_64PS_64SS_CoRegistered/'

    # The training loader must exist: the Training cell calls train(train_loader, ...)
    # and would otherwise raise NameError on a fresh kernel.
    # The training dataset gets its own copy of `dirs`, so re-pointing `dirs` at
    # the test split below cannot affect it.
    net_mode = 'Training'
    train_dirs = dict(dirs)
    train_dirs['input'] = f'{patches_root}{net_mode} Input/'
    train_dirs['target'] = f'{patches_root}{net_mode} Target/'
    train_dataset = SatelliteDataset(train_dirs, normalization_options, patches_options, readWhileRunning, readFromPatches)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    # `dirs` itself ends up pointing at the test split.
    net_mode = 'Test'
    dirs['input'] = f'{patches_root}{net_mode} Input/'
    dirs['target'] = f'{patches_root}{net_mode} Target/'
    test_dataset = SatelliteDataset(dirs, normalization_options, patches_options, readWhileRunning, readFromPatches)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
else:
    dataset = SatelliteDataset(dirs, normalization_options, patches_options, readWhileRunning, readFromPatches)
    # Half/half split; the training set gets the extra sample when the length is odd.
    test_set_len = len(dataset) // 2
    train_set_len = len(dataset) - test_set_len
    # A seeded generator makes the split identical on every run, so a model
    # trained in one session is not later tested on samples it was trained on.
    split_generator = torch.Generator().manual_seed(0)
    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_set_len, test_set_len], generator=split_generator
    )
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, drop_last=False)

## Training

In [None]:
# No checkpoint path is passed, so train() skips loading any saved state.
train(train_loader, None)

## Testing

In [None]:
# Checkpoint that test() loads into the models before evaluating.
# NOTE(review): the path is hard-coded to a SAM-GAN checkpoint regardless of
# `exp_to_run` — confirm it matches the selected experiment.
dirModel = os.getcwd() + '/drive/My Drive/TFG/Models/SAM-GAN_CiudadReal_All_epoch400.pth'
test(test_loader, dirModel=dirModel)