In [None]:
import os
import sys
import glob
import json 
import pprint
import matplotlib
import copy
import pickle

from matplotlib import pyplot as plt
import numpy as np
import matplotlib as mpl


In [None]:
# Default on-screen resolution for every figure in this notebook.
mpl.rc('figure', dpi=100)


In [None]:
# Root folder holding one ``model_<name>`` sub-directory per trained model.
# Override with the CORN_MODEL_OUTPUT_DIR environment variable when the training
# outputs live somewhere else; the original absolute path is kept as the default.
MODEL_TRAINING_OUTPUT_BASE_DIR = os.environ.get(
    'CORN_MODEL_OUTPUT_DIR', '/media/data/local/corn/out/from_drive/')


def get_dir_path_for_model(model_name: str) -> str:
    """Return the output directory of `model_name`, i.e. ``<base>/model_<model_name>``."""
    return os.path.join(MODEL_TRAINING_OUTPUT_BASE_DIR, 'model_' + model_name)

# Load data

In [None]:
def remove_prefix(text, prefix):
    """Return `text` with a leading `prefix` removed (unchanged if it is absent)."""
    return text[len(prefix):] if text.startswith(prefix) else text


# Sorted because os.listdir returns entries in arbitrary order; this keeps the order
# of `results` (and therefore of printed tables, bar charts and legend colours)
# the same on every machine.
dir_names = sorted(
    name
    for name in os.listdir(MODEL_TRAINING_OUTPUT_BASE_DIR)
    if os.path.isdir(os.path.join(MODEL_TRAINING_OUTPUT_BASE_DIR, name))
)
results = {}  # model name -> parsed content of its model_*_result.json

for dir_name in dir_names:
    if not dir_name.startswith('model_'):
        print(f'Skipping directory {dir_name}')
        continue

    dir_path = os.path.join(MODEL_TRAINING_OUTPUT_BASE_DIR, dir_name)
    model_result_files = glob.glob(os.path.join(dir_path, 'model_*_result.json'))
    # Exactly one result file is expected per model directory; zero or several is ambiguous.
    if len(model_result_files) != 1:
        print(f'Invalid directory results in {dir_name}: '
              f'expected 1 result file, found {len(model_result_files)}')
        continue

    model_name = remove_prefix(dir_name, 'model_')
    # JSON is UTF-8 by spec; do not depend on the platform's default encoding.
    with open(model_result_files[0], 'rt', encoding='utf-8') as file:
        results[model_name] = json.load(file)
        

In [None]:
# Keys stored for one model's result (the first one loaded), to see what is available.
next(iter(results.values())).keys()

# Parse results to metrics

In [None]:
# Test-set metrics collected per model: {metric_name: {model_name: value}}.
# Only the class-1 IoU is collected; IoU_Class0 and IoU_Class2 are not.
test_results_metrics = {
    metric_name: {}
    for metric_name in ('fscore', 'iou_score', 'IoU_Class1', 'accuracy', 'dice_loss')
}

# Per-model list of validation dice loss, one entry per logged validation run.
valid_progresses = {}

for model_name, result in results.items():
    test_log = result['test_log']
    for metric_name, per_model in test_results_metrics.items():
        per_model[model_name] = test_log[metric_name]

    valid_progresses[model_name] = [epoch_log['dice_loss'] for epoch_log in result['valid_logs_vec']]
    
    


# Printing/Plotting utils

In [None]:
def plot_metrics_for_models(models_to_plot=None, metrics=None):
    """Draw one bar chart per test metric, comparing models.

    Args:
        models_to_plot: optional iterable of model names to include; ``None`` plots
            every model. Names that have no value for a metric are ignored.
        metrics: optional ``{metric_name: {model_name: value}}`` mapping; defaults to
            the notebook-level ``test_results_metrics``.

    Bars are sorted by descending value, and the y-range is padded by 0.1 on each
    side so that small differences between models stay visible.
    """
    if metrics is None:
        metrics = test_results_metrics
    selected = None if models_to_plot is None else set(models_to_plot)

    for metric_name, per_model in metrics.items():
        values_by_model = {
            model_name: value
            for model_name, value in per_model.items()
            if selected is None or model_name in selected
        }
        if not values_by_model:
            # min()/max() below would raise ValueError on an empty selection.
            print(f'No data to plot for metric {metric_name}')
            continue

        ranked = sorted(values_by_model.items(), key=lambda item: item[1], reverse=True)
        # Plain lists: dict views are not reliably accepted by Axes.bar.
        names = [name for name, _ in ranked]
        values = [value for _, value in ranked]

        _, ax = plt.subplots()
        ax.bar(names, values)
        ax.tick_params(axis='x', labelrotation=90)
        ax.set_ylim(min(values) - 0.1, max(values) + 0.1)
        ax.set_title(metric_name)
        ax.grid()
        plt.show()
        
        
def print_metrics_for_models(models_to_print=None, metrics_to_print=None,
                             metrics=None, header_label='tile size / metric'):
    """Print selected test metrics for selected models: one line per value, then a table.

    Args:
        models_to_print: model names (table rows); defaults to every model that has
            the first selected metric.
        metrics_to_print: metric names (table columns); defaults to all metrics.
        metrics: optional ``{metric_name: {model_name: value}}`` mapping; defaults to
            the notebook-level ``test_results_metrics``.
        header_label: text of the table's top-left cell. The default matches the
            original tile-size comparison; pass e.g. ``'model / metric'`` elsewhere.

    Values are formatted with 3 decimals. The table is tab-separated, with a blank
    line between model rows.

    Raises:
        KeyError: if a requested model or metric has no recorded value.
    """
    if metrics is None:
        metrics = test_results_metrics
    if metrics_to_print is None:
        metrics_to_print = list(metrics.keys())
    if not metrics_to_print:
        # Nothing to tabulate; the default model lookup below would raise IndexError.
        print('No metrics selected')
        return
    if models_to_print is None:
        models_to_print = list(metrics[metrics_to_print[0]].keys())

    lines = ['', f'{header_label} \t' + '\t'.join(metrics_to_print)]
    for model_name in models_to_print:
        formatted = []
        for metric_name in metrics_to_print:
            value = metrics[metric_name][model_name]
            print(f'{metric_name} for {model_name} = {value:.3f}')
            formatted.append(f'{value:.3f}')
        # Trailing '\n' keeps the blank line between rows of the original layout.
        lines.append(model_name + '\t' + '\t'.join(formatted) + '\n')

    print('\n'.join(lines) + '\n')

    

# Print/plot comparison for tile size

In [None]:
# UNET++ / EfficientNet-B0 runs that differ only in tile size.
MODELS_TO_PRINT = [
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0__384px',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0__1152px',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0',
]

print_metrics_for_models(
    models_to_print=MODELS_TO_PRINT,
    metrics_to_print=['fscore', 'IoU_Class1', 'accuracy'],
)

In [None]:
# Only the first epochs are plotted so the runs are compared over the same range.
# NOTE(review): 27 is presumably the length of the shortest of these runs — confirm.
MAX_EPOCHS_TO_PLOT = 27

# Human-readable legend labels. Models missing here fall back to their raw name;
# a bare `legend_map.get` would put the literal text "None" into the legend.
legend_map = {
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0__384px': 'Tile size 256 pixels (7.68 m)',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0__1152px': 'Tile size 768 pixels (23.04 m)',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0': 'Tile size 512 pixels (15.36 m)',
}

fig, ax = plt.subplots(figsize=(10, 7))
# Iterate MODELS_TO_PRINT (not valid_progresses) so line colours and legend order
# follow this list rather than the directory-listing order of the loaded results.
for model_name in MODELS_TO_PRINT:
    if model_name not in valid_progresses:
        print(f'No validation progress for {model_name}')
        continue
    ax.plot(valid_progresses[model_name][:MAX_EPOCHS_TO_PLOT],
            label=legend_map.get(model_name, model_name))

ax.set_title('Training progress on validation data (model UNET++)')
ax.grid()
ax.legend()
ax.set_xlabel('epoch number')
ax.set_ylabel('dice loss')
plt.show()

# Print/plot comparison for all models

## Print all metrics

In [None]:
# Names of every model that has a test F-score.
[*test_results_metrics['fscore']]

In [None]:
# Raw {metric_name: {model_name: value}} dict of test-set metrics, shown for inspection.
test_results_metrics

In [None]:
# One bar chart per metric, covering every loaded model.
plot_metrics_for_models(models_to_plot=None)

# Print/plot for encoder comparison

In [None]:
# UNET++ with different encoders (EfficientNet B0–B4, ResNet50, ResNet18).
# 'UNET_PLUS_PLUS__RESNET50' is not listed: it is the same as 'UNET_PLUS_PLUS__RESNET50_1'.
MODELS_TO_PRINT = [
    'UNET_PLUS_PLUS__EFFICIENT_NET_B0',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B1',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B2',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B3__big_lr_as_always',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B4',
    'UNET_PLUS_PLUS__RESNET50_1',
    'UNET_PLUS_PLUS__RESNET18',
]

print_metrics_for_models(
    models_to_print=MODELS_TO_PRINT,
    metrics_to_print=['fscore', 'IoU_Class1', 'accuracy'],
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 7))
# Iterate MODELS_TO_PRINT (not valid_progresses) so line colours and legend order
# follow this list rather than the directory-listing order of the loaded results.
for model_name in MODELS_TO_PRINT:
    if model_name not in valid_progresses:
        print(f'No validation progress for {model_name}')
        continue
    ax.plot(valid_progresses[model_name], label=model_name)

ax.set_title('Training progress on validation data (model UNET++)')
ax.grid()
ax.legend()
ax.set_xlabel('epoch number')
ax.set_ylabel('dice loss')
plt.show()

# Print/plot models comparison

In [None]:
# Different segmentation architectures. Not included in this comparison:
# 'DEEP_LAB_V3__default_seed_colab', 'DEEP_LAB_V3__seed_776_colab'.
MODELS_TO_PRINT = [
    'PAN',
    'DEEP_LAB_V3_pc_seed_778',
    'FPN',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B3__big_lr_as_always',
    'DEEP_LAB_V3_PLUS',
    'UNET',
    'LINKNET',
]

plot_metrics_for_models(MODELS_TO_PRINT)
print_metrics_for_models(
    models_to_print=MODELS_TO_PRINT,
    metrics_to_print=['fscore', 'IoU_Class1', 'accuracy'],
)

# Print/plot comparison for NDVI data 

In [None]:
# UNET++ / EfficientNet-B3 on NDVI only, RGB only (from the NDVI dataset) and RGB+NDVI.
# Not included: the EfficientNet-B0 variants 'UNET_PLUS_PLUS__EFFICIENT_NET_B0_ndvi',
# 'UNET_PLUS_PLUS__EFFICIENT_NET_B0_ndvi_file_but_rgb_only',
# 'UNET_PLUS_PLUS__EFFICIENT_NET_B0_only_ndvi'.
MODELS_TO_PRINT = [
    'UNET_PLUS_PLUS__EFFICIENT_NET_B3_only_ndvi',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B3_ndvi_but_only_rgb',
    'UNET_PLUS_PLUS__EFFICIENT_NET_B3_ndvi2',
]

plot_metrics_for_models(MODELS_TO_PRINT)
print_metrics_for_models(
    models_to_print=MODELS_TO_PRINT,
    metrics_to_print=['fscore', 'IoU_Class1', 'accuracy'],
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 7))
# Iterate MODELS_TO_PRINT (not valid_progresses) so line colours and legend order
# follow this list rather than the directory-listing order of the loaded results.
for model_name in MODELS_TO_PRINT:
    if model_name not in valid_progresses:
        print(f'No validation progress for {model_name}')
        continue
    ax.plot(valid_progresses[model_name], label=model_name)

ax.set_title('Training progress on validation data (model UNET++)')
ax.grid()
ax.legend()
ax.set_xlabel('epoch number')
ax.set_ylabel('dice loss')
plt.show()

# Plot images for NDVI data

In [None]:
def get_prefiction_figure_for_model(model_name: str, prediction_group: int):
    """Unpickle a saved test-prediction figure of `model_name`.

    Reads ``<model dir>/figures/test_predictions_<prediction_group>.pickle``.
    (The function name's "prefiction" typo is kept so existing calls keep working.)

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on trusted,
    locally produced training outputs.
    """
    figure_path = os.path.join(
        get_dir_path_for_model(model_name),
        'figures',
        f'test_predictions_{prediction_group}.pickle',
    )
    with open(figure_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# Higher resolution for the comparison figure below. It is passed to that figure only,
# instead of mutating mpl.rcParams, so later cells keep the dpi configured at the top.
COMPARISON_FIGURE_DPI = 200

model_rgb_and_ndvi = 'UNET_PLUS_PLUS__EFFICIENT_NET_B3_ndvi2'
model_only_ndvi = 'UNET_PLUS_PLUS__EFFICIENT_NET_B3_only_ndvi'

# Which saved test-prediction figure to use (PREDICTION_GROUP: 0 to 11) and which
# row of it (ROW: 0 to 3). Other examples worth trying: (6, 0), (4, 0).
PREDICTION_GROUP, ROW = (2, 1)

# Number of subplots per row in each model's saved prediction figure.
# NOTE(review): taken from the original ROW*8 / ROW*9 indexing; confirm against the
# code that produced the pickled figures.
ONLY_NDVI_AXES_PER_ROW = 8
RGB_AND_NDVI_AXES_PER_ROW = 9


def _image_from_figure(figure, axes_per_row: int, row: int, column: int):
    """Return the array drawn by the first image in subplot (row, column) of `figure`."""
    return figure.axes[row * axes_per_row + column].images[0].get_array()


only_ndvi_fig = get_prefiction_figure_for_model(model_only_ndvi, prediction_group=PREDICTION_GROUP)
rgb_and_ndvi_fig = get_prefiction_figure_for_model(model_rgb_and_ndvi, prediction_group=PREDICTION_GROUP)

# Read the image data back out of the saved plots.
ndvi_image = _image_from_figure(only_ndvi_fig, ONLY_NDVI_AXES_PER_ROW, ROW, 0)
mask_image = _image_from_figure(only_ndvi_fig, ONLY_NDVI_AXES_PER_ROW, ROW, 1)
only_ndvi_prediction = _image_from_figure(only_ndvi_fig, ONLY_NDVI_AXES_PER_ROW, ROW, 2)

rgb_image = _image_from_figure(rgb_and_ndvi_fig, RGB_AND_NDVI_AXES_PER_ROW, ROW, 0)
rgb_ndvi_prediction = _image_from_figure(rgb_and_ndvi_fig, RGB_AND_NDVI_AXES_PER_ROW, ROW, 3)

# (image, caption) of each panel, left to right.
panels = [
    (rgb_image, 'a) RGB image'),
    (ndvi_image, 'b) NDVI image'),
    (mask_image, 'c) damaged area mask'),
    (only_ndvi_prediction, 'd) prediction on NDVI'),
    (rgb_ndvi_prediction, 'e) prediction on RGB+NDVI'),
]

fontsize = 15
title_offset_y = -0.13  # negative y places each caption below its image
rows = 1
columns = len(panels)
fig = plt.figure(figsize=(columns * 4, rows * 4), dpi=COMPARISON_FIGURE_DPI)
for panel_index, (image, caption) in enumerate(panels, start=1):
    ax = fig.add_subplot(rows, columns, panel_index)
    ax.imshow(image)
    ax.axis('off')
    ax.set_title(caption, fontsize=fontsize, y=title_offset_y)





In [None]:
# dummy = plt.figure()
# new_manager = dummy.canvas.manager
# new_manager.canvas.figure = rgb_and_ndvi_fig
# rgb_and_ndvi_fig.set_canvas(new_manager.canvas)