In [None]:
from __future__ import print_function
import numpy as np
import math
import scipy
import pandas as pd
import PIL
import gdal
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import sys, os
from pathlib import Path
import time
import xml.etree.ElementTree as ET
import random
import collections, functools, operator
import csv

import ee

from osgeo import gdal,osr
from gdalconst import *
import subprocess
from osgeo.gdalconst import GA_Update

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, MSELoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout, Sigmoid
from torch.optim import Adam, SGD
from torchvision import transforms, utils

import skimage
from skimage import io, transform
import sklearn
import sklearn.metrics
from sklearn.feature_extraction import image
from sklearn import svm

# CNNR: model architecture, training and evaluation routines

In [None]:
class CNNnet(Module):
    """Small CNN mapping a multi-band square window to an ``n_bands``-long vector.

    Three 3x3 convolutions (stride 1, padding 1, so the spatial size of the
    window is preserved) are followed by a two-layer fully-connected head with
    dropout.

    Args:
        in_channels: number of input bands/channels (default 9).
        n_bands: length of the output vector (default 170).
        patch_size: side length of the square input window (default 5).
        hidden_units: width of the hidden fully-connected layer (default 170).
        dropout: dropout probability in the head (default 0.5).

    The submodule names (``cnn_layers``, ``linear_layers``) and the position of
    every layer inside them are fixed, so state_dicts saved with the default
    configuration keep loading.
    """

    def __init__(self, in_channels=9, n_bands=170, patch_size=5,
                 hidden_units=170, dropout=0.5):
        super(CNNnet, self).__init__()

        self.cnn_layers = Sequential(
            Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            ReLU(inplace=True),
            Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            ReLU(inplace=True),
            Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            ReLU(inplace=True),
        )

        # padding=1 with 3x3 kernels keeps the window size, so the flattened
        # feature vector has 8 * patch_size * patch_size entries.
        self.linear_layers = Sequential(
            Linear(8 * patch_size * patch_size, hidden_units),
            ReLU(inplace=True),
            Dropout(dropout),
            Linear(hidden_units, n_bands),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, patch_size, patch_size) to (N, n_bands)."""
        x = self.cnn_layers(x)
        x = torch.flatten(x, start_dim=1)
        return self.linear_layers(x)

def extract_miniPatches(input, target, patch_size=5):
    """Split image patches into one square mini-patch per pixel.

    Only element 0 along the first axis of ``input`` and ``target`` is used;
    any further elements are ignored (presumably the loader batch size is 1).

    Args:
        input: tensor of shape (N, P, B, H, W) -- P patches with B bands each.
        target: tensor of shape (N, P, C, H, W) -- C target values per pixel.
        patch_size: odd side length of the square window centred on each
            pixel (default 5). Borders are padded by replicating edge values.

    Returns:
        (inputs, targets): tensors of shape (P*H*W, B, patch_size, patch_size)
        and (P*H*W, C), ordered by patch, then row, then column. Both own
        fresh memory (they do not alias ``input``/``target``).

    Raises:
        ValueError: if ``patch_size`` is not a positive odd integer.
    """
    if patch_size < 1 or patch_size % 2 == 0:
        raise ValueError(f"patch_size must be a positive odd integer, got {patch_size}")
    half = patch_size // 2

    bands_in = input[0].numpy()    # (P, B, H, W)
    bands_out = target[0].numpy()  # (P, C, H, W)
    n_patches, n_bands, height, width = bands_in.shape
    n_pixels = n_patches * height * width

    # Edge-replicate only the two spatial axes of every patch/band at once.
    padded = np.pad(bands_in, ((0, 0), (0, 0), (half, half), (half, half)), mode='edge')

    # windows[p, b, h, w, i, j] == padded[p, b, h + i, w + j]
    windows = np.empty((n_patches, n_bands, height, width, patch_size, patch_size),
                       dtype=bands_in.dtype)
    for i in range(patch_size):
        for j in range(patch_size):
            windows[..., i, j] = padded[:, :, i:i + height, j:j + width]

    # (P, B, H, W, k, k) -> (P, H, W, B, k, k) -> (P*H*W, B, k, k)
    inputs = windows.transpose(0, 2, 3, 1, 4, 5).reshape(n_pixels, n_bands, patch_size, patch_size)
    # (P, C, H, W) -> (P, H, W, C) -> (P*H*W, C); copy so the result never
    # shares memory with the caller's target tensor.
    targets = bands_out.transpose(0, 2, 3, 1).reshape(n_pixels, bands_out.shape[1]).copy()

    return torch.from_numpy(np.ascontiguousarray(inputs)), torch.from_numpy(targets)


def CNNtrain(train_loader, log_every=1):
    """Train the global ``net`` on ``train_loader`` for ``num_epochs`` epochs.

    Relies on notebook-level globals defined elsewhere: ``net``, ``optimizer``,
    ``loss_fn``, ``device``, ``num_epochs``, ``readFromPatches``,
    ``extract_miniPatches``, ``show_patches`` and ``calc_metrics``.

    Each sample is a dict with ``'input'`` and ``'target'`` tensors. When
    ``readFromPatches`` is true the sample holds whole 64x64 patches that are
    split into per-pixel mini-patches here; otherwise the sample already holds
    64*64 mini-patches of shape (9, 5, 5) and their 170-value targets.

    Every ``log_every`` epochs the mean epoch loss is printed, the last batch
    of the epoch is visualised and scored, and a checkpoint with the network
    and optimizer state is saved.

    Args:
        train_loader: iterable of samples (e.g. a DataLoader).
        log_every: run the print/visualise/checkpoint step every this many
            epochs (default 1, i.e. every epoch).

    Returns:
        list of float: mean training loss per epoch.

    Raises:
        ValueError: if ``log_every`` < 1 or ``train_loader`` yields no samples.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    net.train()  # enable Dropout for training
    train_losses = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        n_batches = 0
        for sample in train_loader:  # each sample is a big patch split into mini-patches
            inputs = sample['input']
            targets = sample['target']

            if readFromPatches:
                inputs, targets = extract_miniPatches(inputs, targets)
            else:
                # reshape replaces the deprecated non-inplace Tensor.resize
                inputs = inputs.reshape(64 * 64, 9, 5, 5)
                targets = targets.reshape(64 * 64, 170)

            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            prediction = net(inputs)
            loss = loss_fn(prediction, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_loader yielded no samples")
        train_losses.append(epoch_loss / n_batches)

        ### Visualization / checkpoint (uses the last batch of the epoch) ###
        if epoch % log_every == 0:
            print(f"Epoch {epoch}: CNN loss: {train_losses[-1]}")

            if readFromPatches:
                n_patches = len(sample['input'][0])
                # (P*64*64, C) -> (P, C, 64, 64)
                target_img = targets.reshape(n_patches, 64, 64, 170).permute(0, 3, 1, 2).detach().cpu().numpy()
                prediction_img = prediction.reshape(n_patches, 64, 64, 170).permute(0, 3, 1, 2).detach().cpu().numpy()
                # keep only the centre pixel (2, 2) of each 5x5 window
                input_img = inputs.reshape(n_patches, 64, 64, 9, 5, 5).permute(0, 3, 1, 2, 4, 5)[:, :, :, :, 2, 2].detach().cpu().numpy()
            else:
                # (4096, 170) -> (170, 4096) -> (170, 64, 64); reshape, not
                # permute: permute takes axis indices, not target sizes.
                target_img = targets.transpose(1, 0).reshape(170, 64, 64).detach().cpu().numpy()
                prediction_img = prediction.transpose(1, 0).reshape(170, 64, 64).detach().cpu().numpy()
                input_img = inputs.reshape(64, 64, 9, 5, 5).permute(2, 0, 1, 3, 4)[:, :, :, 2, 2].detach().cpu().numpy()

            show_patches(input_img, prediction_img, target_img)

            # verbose=True prints the batch metrics; the returned dict is unused
            calc_metrics(target_img, prediction_img, verbose=True)

            torch.save({'net': net.state_dict(),
                        'net_opt': optimizer.state_dict()
                        }, os.getcwd() + f"/drive/My Drive/TFG/Models/CNN_CiudadReal_All_progress/epoch{epoch}.pth")

    return train_losses


def CNNtest(inferenceDataset, vizImages=False, svc=None, saveMetrics=None):
    """Evaluate the global ``net`` on every sample of ``inferenceDataset``.

    Relies on notebook-level globals defined elsewhere: ``net``, ``device``,
    ``extract_miniPatches``, ``calc_metrics``, ``show_patches`` and
    ``show_metrics``.

    For each sample, the patches are split into per-pixel mini-patches, then
    predicted in eval mode with gradients off. The predictions and targets are
    reshaped back to (P, 170, 64, 64) cubes and scored with ``calc_metrics``.

    The scores are averaged over all samples:
    - band-wise metrics (PCC, RMSE, PSNR, SSIM; 170 values each);
    - pixel-wise metrics (SAM, SID; 64*64 values each).

    The network's previous train/eval mode is restored afterwards.

    Args:
        inferenceDataset: iterable of dicts with ``'input'`` and ``'target'``
            tensors (and ``'crop'`` when ``svc`` is given).
        vizImages: if True, call ``show_patches`` for every sample.
        svc: optional object whose ``test(crop, prediction)`` returns
            ``(true_classes, predicted_classes)``; accuracy is printed per sample.
        saveMetrics: if not None, the averaged metrics are written to
            ``<cwd>/drive/My Drive/TFG/Metrics/CNN_metrics/<saveMetrics>.csv``.

    Returns:
        dict mapping metric name to its averaged numpy array.

    Raises:
        ValueError: if ``inferenceDataset`` yields no samples.
    """
    # Key order fixes the CSV column order.
    metrics = {'PCC': np.zeros(170),
               'RMSE': np.zeros(170),
               'PSNR': np.zeros(170),
               'SSIM': np.zeros(170),
               'SAM': np.zeros(64 * 64),
               'SID': np.zeros(64 * 64)}
    n_samples = 0

    was_training = net.training
    net.eval()  # disable Dropout so predictions are deterministic
    try:
        for sample in inferenceDataset:
            input_p, target_p = extract_miniPatches(sample['input'], sample['target'])

            with torch.no_grad():
                prediction = net(input_p.to(device)).cpu().numpy()

            # (P*64*64, 170) -> (P, 170, 64, 64)
            target = target_p.numpy().reshape(-1, 64, 64, 170).transpose(0, 3, 1, 2)
            prediction = prediction.reshape(-1, 64, 64, 170).transpose(0, 3, 1, 2)

            # VISUALIZATION
            if vizImages:
                show_patches(sample['input'][0].numpy(), prediction, target)

            # BATCH EVALUATION (band-wise and pixel-wise metrics)
            metrics_batch = calc_metrics(target, prediction, verbose=False)
            for key in metrics:
                metrics[key] += metrics_batch[key]
            n_samples += 1

            # CROP CLASSIFICATION
            if svc is not None:
                # NOTE(review): the crop map is wrapped in an extra leading axis
                # (np.array([...])); confirm this is the shape svc.test expects.
                crop = np.array([sample['crop'].numpy()])
                crop_class, pred_class = svc.test(crop, prediction)
                print('Accuracy:', sklearn.metrics.accuracy_score(crop_class, pred_class))
    finally:
        net.train(was_training)

    if n_samples == 0:
        raise ValueError("inferenceDataset yielded no samples")

    # DATASET EVALUATION
    metrics = {k: m / n_samples for k, m in metrics.items()}
    show_metrics(metrics)

    if saveMetrics is not None:
        df = pd.DataFrame({key: pd.Series(value) for key, value in metrics.items()})
        df.to_csv(os.getcwd() + f"/drive/My Drive/TFG/Metrics/CNN_metrics/{saveMetrics}.csv", encoding='utf-8', index=False)

    return metrics
    