In [None]:
import sys
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import logging


import torch
import torch.nn as nn
from torch import optim
import torch.backends.cudnn as cudnn
from torchvision import transforms
from torch.utils.data import DataLoader

from IOU_computations import *
from Data_Handle.dataset_generator import Dataset_sat
from predict_and_evaluate import *
from Data_Handle.data_augmentation import *

import json
from random import randint


In [None]:
# Use matplotlib's interactive notebook backend so figures update in real time in the notebook.
# NOTE(review): `%matplotlib notebook` only works in the classic Notebook UI; in JupyterLab the
# equivalent is `%matplotlib widget` (requires ipympl). No plotting happens in the cells shown here.
%matplotlib notebook 

## Parameters and model loading

#### WATCH OUT: these parameters must be identical to the ones the model was trained with

In [None]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

###################################
# Input / output configuration
###################################

# 9 input channels: 1 panchromatic + 8 pansharpened bands. Patch plotting assumes
# 9 channels, so the pipeline only works with INPUT_CHANNELS=9.
INPUT_CHANNELS = 9

# Building vs. background. The network only works with NB_CLASSES=2.
NB_CLASSES = 2

# Patches are SIZE_PATCH x SIZE_PATCH; must equal the patch size of the dataset.
SIZE_PATCH = 128

###################################
# U-Net architecture
###################################

DEFAULT_LAYERS = 3          # number of U-Net layers (excluding the bottom layer) = number of downsampling stages
DEFAULT_FEATURES_ROOT = 32  # number of filters in the first U-Net layer
DEFAULT_BN = True           # include batch-normalization layers

DEFAULT_FILTER_WIDTH = 3    # convolution kernel size, e.g. 3 -> 3x3
DEFAULT_LR = 1e-3           # training learning rate (1e-3 for SpaceNet and Ghana); kept for reference, not passed to the model below
DEFAULT_N_RESBLOCKS = 1     # residual blocks per stage (heavier network, not advised); not passed to the model below

###################################
# Distance module
###################################

DISTANCE_NET = 'v2'  # set to None to disable the distance module
BINS = 10
THRESHOLD = 20

# The UNet must know whether a distance module is used: passing True without a
# distance module causes an error.
DISTANCE_NET_UNET = DISTANCE_NET is not None

# Type of loss used: 'jaccard_approx' or 'cross-entropy'
LOSS_FN = 'jaccard_approx'

# root_folder='../SPACENET_DATA/SPACENET_DATA_PROCESSED/DATASET/128_x_128_8_bands_pansh/'
root_folder = '../2_DATA_GHANA/DATASET/128_x_128_8_pansh/'
# root_folder ='../2_DATA_GHANA_BALANCED/DATASET/128_x_128_8_pansh/'
path_model = 'TRAINED_MODELS/RUBV3D2_final_model_ghana.pth'
# path_model='TRAINED_MODELS/RUBV3D2_final_model_spacenet.pth'

from RUBV3D2 import UNet

model = UNet(INPUT_CHANNELS, NB_CLASSES,
             depth=DEFAULT_LAYERS,
             n_features_zero=DEFAULT_FEATURES_ROOT,
             width_kernel=DEFAULT_FILTER_WIDTH,
             dropout=0,
             distance_net=DISTANCE_NET_UNET,
             bins=BINS,
             batch_norm=DEFAULT_BN)
model.cuda()
cudnn.benchmark = True
model.load_state_dict(torch.load(path_model))
# Modules start in training mode. This model is only used for inference here, so
# switch to eval mode: with DEFAULT_BN=True, batch-norm layers would otherwise use
# per-batch statistics (and update their running stats) while evaluating.
model.eval()



In [None]:
batch_size = 32
SAVE_PATCHES = True  # NOTE(review): not used in this cell; kept in case later cells read it
test_generator = Dataset_sat.from_root_folder(root_folder + 'TEST/', NB_CLASSES)

# No shuffling for evaluation: metrics are averaged per batch below, so a fixed
# order (and therefore fixed batch composition) makes the reported numbers reproducible.
test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=1)

all_prediction_path = 'TEST_SET_GHANA/'
TMP_IOU = all_prediction_path + 'TMP_IOU/'  # scratch folder passed to predict_score_batch
for folder in (all_prediction_path, TMP_IOU):
    if not os.path.exists(folder):
        os.makedirs(folder)

# Label written to the log / results file; matches the RUBV3D2 architecture and checkpoint.
model_name = 'RUBV3D2'


def _to_float(value):
    """Return `value` (PyTorch scalar tensor/Variable or NumPy/Python number) as a Python float.

    `loss.data[0]` only works on PyTorch < 0.4, where a scalar loss is a 1-element
    tensor; from 0.4 on it is 0-dim and indexing raises IndexError, so `.item()`
    is used when available.
    """
    if hasattr(value, 'item'):
        return float(value.item())
    if hasattr(value, 'data'):
        return float(value.data[0])  # PyTorch < 0.4 Variable
    return float(value)


n_batches = len(test_loader)
if n_batches == 0:
    raise ValueError('No test samples found in ' + root_folder + 'TEST/')

# Inference mode (batch-norm uses running statistics). NOTE(review): if Train_or_Predict
# switches modes itself this is overridden; it is harmless otherwise.
model.eval()

loss_v = 0
error_rate_v = 0
iou_acc_v = 0
f1_v = 0

for i_batch, sample in enumerate(test_loader):
    predict_net = Train_or_Predict(sample, DISTANCE_NET, LOSS_FN, THRESHOLD, BINS, model)
    loss, probs_dist, probs_seg = predict_net.forward_pass()

    prediction_seg_v = probs_seg.data.cpu().numpy()
    groundtruth_seg_v = np.asarray(predict_net.batch_y)
    # Distance outputs are not scored here; kept as notebook globals for later inspection.
    prediction_dist_v = probs_dist.data.cpu().numpy()
    groundtruth_dist = np.asarray(predict_net.batch_y_dist)

    loss_v += _to_float(loss)
    error_rate_v += error_rate(prediction_seg_v, groundtruth_seg_v)
    # argmax over axis 3 turns per-class scores into label maps (assumes channels-last batches).
    iou_acc, f1, _ = predict_score_batch(TMP_IOU, np.argmax(groundtruth_seg_v, 3), np.argmax(prediction_seg_v, 3))
    iou_acc_v += iou_acc
    f1_v += f1

# Mean over batches (not samples): a smaller final batch weighs as much as a full one.
loss_v /= n_batches
error_rate_v /= n_batches
iou_acc_v /= n_batches
f1_v /= n_batches
logging.info("Model {:s}: Verif loss = {:.4f}, Verification  Error rate = {:.4f}%, IOU Precision = {:.4f}%, F1 IOU= {:.4f}%".format(model_name,loss_v,error_rate_v,iou_acc_v,f1_v))

# Cast to float so json.dumps also accepts NumPy scalar types (e.g. float32).
# NOTE(review): the 'Erro rate' key is misspelled but kept for existing readers of this file.
with open(all_prediction_path + 'models_ghana.txt', 'w') as file_results:
    file_results.write(json.dumps({'name': model_name, 'loss': float(loss_v), 'Erro rate': float(error_rate_v),
                                   'IOU accuracy': float(iou_acc_v), 'F1 IOU': float(f1_v)}) + '\n')