In [None]:
import numpy as np
import math
import scipy
import pandas as pd
import PIL
import gdal
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import sys, os
from pathlib import Path
import time
import random
import collections, functools, operator
import csv
import subprocess
import datetime

from osgeo import gdal,osr
from gdalconst import *
import subprocess
from osgeo.gdalconst import GA_Update

import torch
import torch.nn as nn

import skimage
from skimage import io, transform
import sklearn
import sklearn.metrics
from sklearn.feature_extraction import image
from sklearn import svm

# Metrics and utils for model training and evaluation

In [None]:
def pcc(img1, img2): # Pearson's Correlation Coefficient
    """Return the Pearson correlation coefficient between two arrays.

    Both inputs are centred on their own mean; the result is the sum of the
    element-wise product of the centred arrays divided by the product of
    their Euclidean norms.

    Returns the integer 0 when that product of norms is zero (at least one
    input is constant), since the coefficient is undefined there.
    """
    centered1 = img1 - np.mean(img1)
    centered2 = img2 - np.mean(img2)
    covariance = np.sum(centered1 * centered2)
    spread = np.sqrt(np.sum(centered1 ** 2)) * np.sqrt(np.sum(centered2 ** 2))
    # Guard against dividing by zero for constant inputs.
    if spread == 0.:
        return 0
    return np.divide(covariance, spread)
def rmse(img1, img2): # Root Mean Squared Error
    """Return the root of the mean squared element-wise difference."""
    squared_error = (img1 - img2) ** 2
    return np.sqrt(np.mean(squared_error))
def psnr(img1, img2): # Peak Signal-to-Noise Ratio
    """Return the peak signal-to-noise ratio of ``img2`` against ``img1``.

    ``img1`` is passed to scikit-image as the reference image and ``img2``
    as the image under test.

    The data range is the value span of the reference image alone
    (``img1.max() - img1.min()``). Using the other image's minimum would
    make the peak depend on the image being evaluated and disagree with
    ``ssim``, which also uses the span of ``img1``.
    """
    data_range = img1.max() - img1.min()
    return skimage.metrics.peak_signal_noise_ratio(img1, img2, data_range=data_range)
def ssim(img1, img2): # Structural Similarity
    """Return the structural similarity index between the two images.

    The data range handed to scikit-image is the value span of ``img1``.
    """
    value_span = img1.max() - img1.min()
    return skimage.metrics.structural_similarity(img1, img2, data_range=value_span)
def sam(px1, px2): # Spectral Angle Mapper
    """Return the angle in radians between two spectral vectors.

    The angle is ``acos(px1 . px2 / (|px1| * |px2|))``.

    Returns NaN when either vector has zero norm, because the angle is
    undefined in that case.
    """
    num = np.dot(px1, px2)
    denom = np.sqrt(np.dot(px1, px1)) * np.sqrt(np.dot(px2, px2))
    if denom == 0:
        # 0/0 would evaluate to NaN anyway; return it without the warning.
        return math.nan
    # Rounding can push the cosine marginally outside [-1, 1] (typically for
    # parallel vectors), which makes math.acos raise ValueError. np.clip
    # leaves NaN untouched, so NaN inputs still propagate.
    cosine = np.clip(num / denom, -1.0, 1.0)
    return math.acos(cosine)
def sid(px1, px2): # Spectral Information Divergence
    """Return the spectral information divergence between two pixel spectra.

    Each spectrum is normalised by its own sum and the result is the sum of
    the two directed relative entropies, D(p||q) + D(q||p).

    Returns 1.0 without computing anything when ``px2`` contains at least
    one zero (falsy) element. Zeros in ``px1`` are not guarded and can yield
    NaN or infinite values.
    """
    # Short-circuit as soon as any element of px2 is zero.
    if not all(px2):
        return 1.
    dist_real = np.divide(px1, np.sum(px1))
    dist_pred = np.divide(px2, np.sum(px2))
    real_to_pred = np.sum(dist_real * np.log(np.divide(dist_real, dist_pred)))
    pred_to_real = np.sum(dist_pred * np.log(np.divide(dist_pred, dist_real)))
    return real_to_pred + pred_to_real

def bandWise_metrics(batch_real, batch_predicted):
    """Compute PCC, RMSE, PSNR and SSIM for each band, averaged over the batch.

    Both arguments are indexed as ``[patch, band]``: axis 0 is iterated as
    the patches of the batch and axis 1 as the bands.

    Returns a dict with keys 'PCC', 'RMSE', 'PSNR' and 'SSIM', each holding
    an array with one batch-averaged value per band.
    """
    batch_size = batch_real.shape[0]
    scorers = (('PCC', pcc), ('RMSE', rmse), ('PSNR', psnr), ('SSIM', ssim))
    per_band = {label: [] for label, _ in scorers}

    for band in range(batch_real.shape[1]):
        # Running sums over every patch of the batch for this band.
        totals = {label: 0 for label, _ in scorers}
        for real_band, predicted_band in zip(batch_real[:, band], batch_predicted[:, band]):
            for label, scorer in scorers:
                totals[label] += scorer(real_band, predicted_band)
        for label, _ in scorers:
            per_band[label].append(totals[label] / batch_size)

    return {label: np.array(values) for label, values in per_band.items()}

def pixelWise_metrics(batch_real, batch_predicted):
    """Compute SAM and SID for each pixel position, averaged over the batch.

    Both arguments are indexed as ``[patch, band, row, col]``. For every
    (row, col) position the spectra of all patches are compared pairwise.

    NaN SID values at a position are replaced by the mean of the non-NaN SID
    values at that position before averaging. SAM values get no such
    treatment.

    Returns a dict with keys 'SAM' and 'SID', each a flat array with one
    value per pixel position in row-major order.
    """
    batch_size = batch_real.shape[0]
    sam_per_pixel = []
    sid_per_pixel = []

    for row in range(batch_real.shape[2]):
        for col in range(batch_real.shape[3]):
            sam_total = 0
            sid_values = []
            real_spectra = batch_real[:, :, row, col]
            predicted_spectra = batch_predicted[:, :, row, col]
            for real_pix, predicted_pix in zip(real_spectra, predicted_spectra):
                sam_total += sam(real_pix, predicted_pix)
                sid_values.append(sid(real_pix, predicted_pix))
            sam_per_pixel.append(sam_total / batch_size)

            # Impute NaN divergences with the mean of the valid ones.
            sid_values = np.array(sid_values)
            invalid = np.isnan(sid_values)
            sid_values[invalid] = np.mean(sid_values[~invalid])
            sid_per_pixel.append(np.sum(sid_values) / batch_size)

    return {'SAM': np.array(sam_per_pixel), 'SID': np.array(sid_per_pixel)}


def calc_metrics(real, predicted, verbose=False):
    """Gather the band-wise and pixel-wise metrics into a single dict.

    Returns a dict with keys 'PCC', 'RMSE', 'PSNR', 'SSIM' (from
    ``bandWise_metrics``) and 'SAM', 'SID' (from ``pixelWise_metrics``).
    When ``verbose`` is true the metrics are also passed to ``show_metrics``.
    """
    # Band-wise evaluation
    band_scores = bandWise_metrics(real, predicted)
    metrics = {name: band_scores[name] for name in ('PCC', 'RMSE', 'PSNR', 'SSIM')}

    # Pixel-wise evaluation
    pixel_scores = pixelWise_metrics(real, predicted)
    for name in ('SAM', 'SID'):
        metrics[name] = pixel_scores[name]

    if verbose:
        show_metrics(metrics)
    return metrics

def show_metrics(metrics):
    """Print the mean of each metric and draw one box plot per metric.

    ``metrics`` must provide the keys 'PCC', 'RMSE', 'PSNR', 'SSIM', 'SAM'
    and 'SID'. The box plots are laid out on a 2x3 grid and shown
    immediately.
    """
    for name in ('PCC', 'RMSE', 'PSNR', 'SSIM', 'SAM', 'SID'):
        print(f'{name}:', np.mean(metrics[name]))

    # (metric, grid row, grid column) in drawing order.
    layout = (
        ('PCC', 0, 0),
        ('RMSE', 0, 1),
        ('PSNR', 1, 0),
        ('SSIM', 1, 1),
        ('SAM', 0, 2),
        ('SID', 1, 2),
    )
    fig, axs = plt.subplots(2, 3)
    for name, row, col in layout:
        panel = axs[row, col]
        panel.boxplot(metrics[name], showmeans=True)
        panel.set_title(name)
    plt.show()



def _patch_to_rgb(batch, patch, bands):
    """Build a clipped three-channel image from three bands of one patch."""
    # Rescale each band by 65535/4000; the arithmetic yields fresh arrays,
    # so the in-place clipping below does not touch the caller's data.
    channels = [(batch[patch][band] * 65535) / 4000 for band in bands]
    for channel in channels:
        channel[channel < 0] = 0
        channel[channel > 1] = 1
    return np.dstack((channels[0], channels[1], channels[2]))


def show_patches(input, prediction, target, saveSinglePatches=False): # Shows a few patches from a batch
    """Plot the first ten patches of a batch as input / prediction / target.

    Each of ``input``, ``prediction`` and ``target`` is indexed as
    ``[patch][band]``. Bands 4, 3, 2 of ``input`` and bands 160, 45, 9 of
    ``prediction`` and ``target`` are used as the three colour channels.
    At least ten patches are required.

    When ``saveSinglePatches`` is true, patches 0, 2, 4 and 6 are also drawn
    individually and saved as PNG files under
    ``<cwd>/drive/My Drive/TFG/PatchesViz/``, the prediction and the target
    of the same patch side by side in the naming scheme.
    """
    input_bands = (4, 3, 2)
    output_bands = (160, 45, 9)  # original note: "Bands 45 is 37, 21 is 13, 14 is 6" - TODO confirm meaning

    fig, axs = plt.subplots(3, 10)
    fig.set_figwidth(15)
    for patch in range(10):
        rows = (
            ('Input', _patch_to_rgb(input, patch, input_bands)),
            ('Prediction', _patch_to_rgb(prediction, patch, output_bands)),
            ('Target', _patch_to_rgb(target, patch, output_bands)),
        )
        for row, (title, rgb) in enumerate(rows):
            axs[row, patch].imshow(rgb)
            axs[row, patch].set_title(title)
    plt.show()

    if saveSinglePatches:
        out_dir = os.getcwd() + '/drive/My Drive/TFG/PatchesViz/'
        for i, patch in enumerate([0, 2, 4, 6]):
            predicted_rgb = _patch_to_rgb(prediction, patch, output_bands)
            # Rebuild the target for THIS patch; reusing the image left over
            # from the grid loop would save patch 9 under every file name.
            target_rgb = _patch_to_rgb(target, patch, output_bands)

            print('predicted')
            plt.imshow(predicted_rgb)
            plt.axis('off')
            plt.savefig(out_dir + f"SAMGAN_epoch400_p{i}.png", bbox_inches='tight')
            plt.show()

            print('real')
            plt.imshow(target_rgb)
            plt.axis('off')
            plt.savefig(out_dir + f"GT_p{i}.png", bbox_inches='tight')
            plt.show()
    



def weights_init(m):
    """Re-initialise the parameters of a convolution or batch-norm module.

    Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
    a normal distribution with mean 0.0 and std 0.02; BatchNorm2d biases are
    set to 0. Any other module is left unchanged.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, val=0)



class svmClassifier(object):
    """Pixel-wise SVM classifier trained on a subset of spectral bands.

    Every pixel is one sample; its features are the values of the bands
    listed in ``self.PCs``. The same band subset is applied in ``train`` and
    ``test`` so the fitted model and the queries agree on the feature count.
    """

    def __init__(self):
        self.clf = svm.SVC(kernel='poly', degree=2, gamma='auto', C=1.0, decision_function_shape="ovr")
        # Indices of the spectral bands used as features.
        self.PCs = [2, 6, 14, 23, 36, 43, 81, 110, 153]

    @staticmethod
    def _flatten_pixels(batch):
        """Reshape a (N, C, H, W) array to (N*H*W, C): one row per pixel."""
        n, c, h, w = batch.shape
        return batch.transpose(0, 2, 3, 1).reshape(n * h * w, c)

    def _select_bands(self, pixels):
        """Keep only the ``self.PCs`` columns of a (pixels, bands) array."""
        return pixels[:, self.PCs]

    def train(self, train_data):
        """Fit the SVM on the spectra reconstructed for ``train_data``.

        Each sample is a mapping with an 'input' tensor and a 'crop' tensor.
        The input is resized to (64*64, 9, 5, 5), run through ``net`` and
        the output is treated as 170 band values for each of 64*64 pixels.

        NOTE(review): ``net`` and ``device`` are not defined in this class;
        they are presumably module-level globals - confirm before reuse.
        """
        spectral_data = []
        crop_class = []
        for sample in train_data:
            spectral_data.append(net(sample['input'].resize(64*64, 9, 5, 5).to(device)).detach().cpu().numpy())
            crop_class.append(sample['crop'].numpy())

        # One row per pixel, one column per reconstructed band.
        spectral_data = np.array(spectral_data).reshape(len(spectral_data) * 64 * 64, 170)
        crop_class = self._flatten_pixels(np.array(crop_class))

        self.clf.fit(self._select_bands(spectral_data), crop_class)

    def test(self, crop_class, spectral_data):
        """Predict a class for every pixel of ``spectral_data``.

        Both arguments are (N, C, H, W) arrays. Returns the tuple
        ``(flattened crop_class, predictions)`` where the first element has
        shape (N*H*W, C_crop) and the predictions come from the fitted SVM
        applied to the selected bands.
        """
        spectral_data = self._flatten_pixels(spectral_data)
        crop_class = self._flatten_pixels(crop_class)
        return crop_class, self.clf.predict(self._select_bands(spectral_data))