# Inferência

Neste notebook, aplicamos os modelos treinados ao conjunto de teste.

Os modelos treinados (arquivos `.pth`) precisam estar na pasta `models/`, na raiz do projeto.

In [None]:
import torch

def print_gpu_memory(prefix=""):
    """Print the CUDA memory currently allocated and reserved by PyTorch, in MB.

    Args:
        prefix: text placed before each line; a space always separates it from
            the message, so an empty prefix yields a leading space.

    When CUDA is unavailable, prints "CUDA is not available." instead.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return
    mib = 1024 ** 2
    # Query both counters before printing anything.
    stats = (
        ("Allocated", torch.cuda.memory_allocated()),
        ("Reserved", torch.cuda.memory_reserved()),
    )
    for label, n_bytes in stats:
        print(f"{prefix} Memory {label}: {n_bytes / mib:.2f} MB")


# Run on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

print_gpu_memory()

In [None]:
# imports

import os
import sys
sys.path.append(os.path.abspath('..'))

import src.data.preprocess_data as data
import src.training.train_model as train
import src.data.view as view
import src.models.unets as unets
import src.models.hrnets as hrnets
import src.training.post_processing as post

from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


Seleção do modelo (o arquivo `.pth` deve estar na pasta `models/`).

In [None]:
# Checkpoint to analyse (a file in ../models). To inspect another one, replace
# the value below with one of these candidates (notes from earlier runs):
#   'UNetSmall-256-type-CE-diceW.pth'
#   'UNetSmall-64-type-CEW.pth'          # predicts lots of class 4
#   'UNetConvNext-224-type-DS-CEW.pth'   # does the opposite: predicts other classes
#   'UNetSmall-256-type-DS-CEW.pth'      # does the opposite: predicts other classes
#   'UNetConvNext-224-type-CEW.pth'
#   'HRNetW32-512-type-CEW.pth'
model_name = 'UNet-256-type-DS-CE.pth'



Carregam-se os parâmetros de treinamento a partir do nome do modelo.

In [None]:

# Recover the training configuration encoded in the checkpoint file name,
# e.g. 'UNet-256-type-DS-CEW.pth' -> ['UNet', '256', 'type', 'DS', 'CEW.pth'].
model_name_parts = model_name.split('-')
patch_size = int(model_name_parts[1])
# True when the loss token ends in 'W' (e.g. 'CEW.pth').
weighted = 'W.pth' in model_name

# Number of output classes per labelling scheme (same mapping as the test-set
# definitions cell). Fail loudly instead of leaving `num_classes` undefined or
# stale from a previous run.
NUM_CLASSES_BY_MODE = {'type': 5, 'density': 4, 'binary': 2, 'all': 9}
if model_name_parts[2] not in NUM_CLASSES_BY_MODE:
    raise ValueError(f'Unknown classes mode {model_name_parts[2]!r} in model name {model_name!r}')
num_classes = NUM_CLASSES_BY_MODE[model_name_parts[2]]

# The loss token is the last '-' field without its extension; strip the
# trailing 'W' marker to get the loss mode passed to CombinedLoss.
loss_mode = model_name_parts[-1].split('.')[0]
if loss_mode.endswith('W'):
    loss_mode = loss_mode[:-1]

# Optional 4th field enabling the `crf` / `use_dist` options of UNetSmall.
# Short names (fewer than 4 fields) simply disable both.
variant = model_name_parts[3] if len(model_name_parts) > 3 else ''
crf = variant == 'crf'
dist = variant == 'dist'

criterion = train.CombinedLoss(loss_mode = loss_mode, weights = None, return_all_losses=True)


Instancia o modelo e carrega o checkpoint, de acordo com o prefixo do nome do arquivo.

In [None]:


working_dir = os.path.abspath('..')
ckp_file = os.path.join(working_dir, 'models', model_name)

print('Model name:', model_name)
# Instantiate the architecture named by the file-name prefix. A single
# if/elif chain guarantees exactly one branch runs. An unknown prefix fails
# loudly instead of leaving `model` undefined (or stale from an earlier run).
if model_name.startswith('UNetSmall-'):
    model = unets.UNetSmall(in_channels=12, out_channels=num_classes, crf=crf, use_dist=dist).to(device)
elif model_name.startswith('UNet-'):
    model = unets.UNet(in_channels=12, out_channels=num_classes).to(device)
elif model_name.startswith('UNetResNet34-'):
    model = unets.UNetResNet34(in_channels=12, out_channels=num_classes).to(device)
elif model_name.startswith('UNetEfficientNetB0-'):
    model = unets.UNetEfficientNetB0(in_channels=12, out_channels=num_classes).to(device)
elif model_name.startswith('UNetConvNext-'):
    model = unets.UNetConvNext(in_channels=12, out_channels=num_classes).to(device)
elif model_name.startswith('HRNetW18'):
    model = hrnets.HRNetSegmentation(in_channels=12, num_classes=num_classes, backbone="hrnet_w18_small", pretrained=True).to(device)
elif model_name.startswith('HRNetW32'):
    model = hrnets.HRNetSegmentation(in_channels=12, num_classes=num_classes, backbone="hrnet_w32", pretrained=True).to(device)
elif model_name.startswith('HRNetW48'):
    model = hrnets.HRNetSegmentation(in_channels=12, num_classes=num_classes, backbone="hrnet_w48", pretrained=True).to(device)
else:
    raise ValueError(f'Unknown model architecture in {model_name!r}')

# weights_only=False unpickles arbitrary Python objects: only load trusted
# checkpoints. map_location lets a checkpoint saved from a GPU be opened on a
# CPU-only machine (`device` falls back to "cpu" when CUDA is unavailable).
checkpoint = torch.load(ckp_file, map_location=device, weights_only=False)

start_epoch = checkpoint['epoch'] + 1
best_epoch = checkpoint['best_epoch']
# Restore the weights stored under 'best_model_state_dict'.
model.load_state_dict(checkpoint['best_model_state_dict'])
best_val_loss = checkpoint['best_val_loss']
best_epoch_info = checkpoint['best_epoch_info']
current_lr = checkpoint['current_lr']
# Key spelled 'parience' - presumably how the training code saves it; change
# both sides together if it is ever fixed.
current_patience = checkpoint['current_parience']
metadata = checkpoint['metadata']
info = checkpoint['best_epoch_info']
# Per-epoch metric records used by the plotting cells below.
history = checkpoint['history']


Mostra as estatísticas de treinamento registradas (apenas a primeira época).

In [None]:
# Show the statistics recorded for the first training epoch only (nothing is
# printed when `history` is empty).
for info in history[:1]:
    print(info['train_acc'])
    print(info)


In [None]:
import matplotlib.pyplot as plt
def plot_col(history, column):
    """Plot the training and validation curves of one metric over the epochs.

    Args:
        history: sequence of per-epoch dicts holding 'train_<column>' and
            'val_<column>' entries.
        column: metric suffix, e.g. 'loss' or 'acc'.

    Also prints the list of training values.
    """
    train_key, val_key = f'train_{column}', f'val_{column}'
    train_metric = [epoch_info[train_key] for epoch_info in history]
    val_metric = [epoch_info[val_key] for epoch_info in history]
    print(train_metric)

    fig, ax = plt.subplots(figsize=(8, 6))
    epochs = range(1, len(train_metric) + 1)

    # One line per split: values, legend label, colour.
    curves = (
        (train_metric, f"Training {column}", "blue"),
        (val_metric, f"Validation {column}", "red"),
    )
    for values, label, color in curves:
        ax.plot(epochs, values, label=label, marker="o", linestyle="-", color=color)

    ax.set_title(f"Training and Validation {column}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(column)
    ax.legend()
    ax.grid(True)


# Metrics recorded per epoch in `history` (suffixes of the 'train_*'/'val_*' keys).
columns = ['loss', 'acc', 'micro', 'macro', 'weighted', 'f1_C0', 'f1_C1', 'f1_C2', 'f1_C3', 'f1_C4', 'CE', 'dice']


def get_best(history):
    """Return the epoch record with the lowest validation loss.

    Args:
        history: iterable of per-epoch dicts, each holding a 'val_loss' entry.

    Returns:
        The dict of the epoch with the smallest 'val_loss'. On ties the later
        epoch wins (the comparison uses <=). Returns an empty dict when
        `history` is empty.
    """
    lowest_loss = np.inf
    info_best = {}
    for epoch_info in history:
        if epoch_info['val_loss'] <= lowest_loss:
            lowest_loss = epoch_info['val_loss']
            info_best = epoch_info
    # Return the tracked best record, not the loop variable (the last epoch).
    return info_best


Visualiza a evolução do loss e das demais métricas ao longo das épocas do modelo treinado.

In [None]:
# Training/validation curve of every metric recorded in `history`.
columns = ['loss', 'acc', 'micro', 'macro', 'weighted',
           *[f'f1_C{i}' for i in range(5)],
           'CE', 'dice']
for c in columns:
    view.plot_metrics(history, c)


# Aplicação no conjunto de teste

Definições

In [None]:
# Test-set configuration.
tiles = ['032027']#, '032026']
num_subtiles = 6
classes_mode = 'type'
training_batch_size = 16
model_types = 'unets'
# NOTE(review): overrides the `weighted` flag parsed from the model name;
# no visible cell reads it afterwards.
weighted = True

# Number of output classes for each labelling scheme. Fail loudly on an
# unknown mode instead of silently keeping a stale `num_classes`.
NUM_CLASSES_BY_MODE = {'type': 5, 'density': 4, 'binary': 2, 'all': 9}
if classes_mode not in NUM_CLASSES_BY_MODE:
    raise ValueError(f'Unknown classes_mode {classes_mode!r}; expected one of {sorted(NUM_CLASSES_BY_MODE)}')
num_classes = NUM_CLASSES_BY_MODE[classes_mode]


working_dir = os.path.abspath('..')
# Checkpoints in ../models whose name ends with 'CE.pth' or 'CEW.pth',
# in alphabetical order.
models_paths = sorted(
    name
    for name in os.listdir(os.path.join(working_dir, 'models'))
    if name.endswith(('CE.pth', 'CEW.pth'))
)



In [None]:
# List the checkpoint files selected for evaluation.
print(models_paths)

In [None]:
# Stratified split of the sub-tiles (60% train, 20% val; the rest presumably test),
# stratified by the current labelling scheme, using the 'q_12ch' sub-folder.
train_files, val_files, test_files = data.train_val_test_stratify(
    tiles,
    num_subtiles,
    train_size=0.6,
    val_size=0.2,
    stratify_by=classes_mode,
    subfolder='q_12ch',
)

## Loop de teste

Criamos as estruturas de dados que armazenam o melhor modelo para cada métrica (recall, precisão e F1) e avaliamos todos os modelos no conjunto de teste.

In [None]:
# Evaluate every checkpoint in `models_paths` on the test set and keep, for
# each metric, the best model and the top-5 ranking.
metric = 'recall'

# Best-model trackers for the regrouped label sets (binary, 4 and 5 classes).
# Nothing in this cell updates them; they are initialised so that the summary
# cell printing them can run.
highest_binary_macro = {'model':'', 'value': 0.0}
highest_4class_macro = {'model':'', 'value': 0.0}
highest_5class_macro = {'model':'', 'value': 0.0}

highest_binary_weighted = {'model':'', 'value': 0.0}
highest_4class_weighted = {'model':'', 'value': 0.0}
highest_5class_weighted = {'model':'', 'value': 0.0}

highest_binary_global = {'model':'', 'value': 0.0}
highest_4class_global = {'model':'', 'value': 0.0}
highest_5class_global = {'model':'', 'value': 0.0}

binary_labels = ["(0,3)", "(1,2,4)"]
join_labels = ["(0,3)", "(1)", "(2)", "(4)"]
class_labels = [str(cl) for cl in range(num_classes)]

highest_binary = [{'model':'', 'class':'(0,3)','value': 0.0}, {'model':'', 'class':'(1,2,4)', 'value': 0.0}]
highest_4class = [{'model':'', 'class':'(0,3)', 'value': 0.0}, {'model':'', 'class':'1', 'value': 0.0}, {'model':'', 'class':'2', 'value': 0.0}, {'model':'', 'class':'4', 'value': 0.0}]
highest_5class = [{'model':'', 'class':'0', 'value': 0.0}, {'model':'', 'class':'1', 'value': 0.0}, {'model':'', 'class':'2', 'value': 0.0}, {'model':'', 'class':'3', 'value': 0.0}, {'model':'', 'class':'4', 'value': 0.0}]

# Best model per metric on the test set: 'global' = weighted avg,
# 'macro' = macro avg, plus one tracker per class. Recall has its own
# trackers so it never overwrites the precision ones.
highest_precision_global = {'model':'', 'value': 0.0}
highest_precision_macro = {'model':'', 'value': 0.0}
highest_recall_global = {'model':'', 'value': 0.0}
highest_recall_macro = {'model':'', 'value': 0.0}
highest_f1_global = {'model':'', 'value': 0.0}
highest_f1_macro = {'model':'', 'value': 0.0}
highest_precision = [{'model':'', 'class':cl, 'value': 0.0} for cl in class_labels]
highest_recall = [{'model':'', 'class':cl, 'value': 0.0} for cl in class_labels]
highest_f1 = [{'model':'', 'class':cl, 'value': 0.0} for cl in class_labels]

# Every model's score for each report row; trimmed to the top 5 after the loop.
_class_rows = [f'Class {i}' for i in range(num_classes)]
_report_rows = ['macro avg', 'weighted avg'] + _class_rows
ordered_f1s = {row: [] for row in _report_rows}
ordered_recalls = {row: [] for row in _report_rows}
ordered_precisions = {row: [] for row in _report_rows}


def _parse_model_name(name):
    """Return (patch_size, crf, dist) encoded in a checkpoint name such as 'UNet-256-type-DS-CE.pth'."""
    parts = name.split('-')
    variant = parts[3] if len(parts) > 3 else ''
    return int(parts[1]), variant == 'crf', variant == 'dist'


def _build_model(name, n_classes, use_crf, use_dist):
    """Instantiate on `device` the architecture named by the checkpoint prefix; None if unknown."""
    if name.startswith('UNetSmall-'):
        net = unets.UNetSmall(in_channels=12, out_channels=n_classes, crf=use_crf, use_dist=use_dist)
    elif name.startswith('UNet-'):
        net = unets.UNet(in_channels=12, out_channels=n_classes)
    elif name.startswith('UNetResNet34-'):
        net = unets.UNetResNet34(in_channels=12, out_channels=n_classes)
    elif name.startswith('UNetEfficientNetB0-'):
        net = unets.UNetEfficientNetB0(in_channels=12, out_channels=n_classes)
    elif name.startswith('UNetConvNext-'):
        net = unets.UNetConvNext(in_channels=12, out_channels=n_classes)
    elif name.startswith('HRNetW18'):
        net = hrnets.HRNetSegmentation(in_channels=12, num_classes=n_classes, backbone="hrnet_w18_small", pretrained=True)
    elif name.startswith('HRNetW32'):
        net = hrnets.HRNetSegmentation(in_channels=12, num_classes=n_classes, backbone="hrnet_w32", pretrained=True)
    elif name.startswith('HRNetW48'):
        net = hrnets.HRNetSegmentation(in_channels=12, num_classes=n_classes, backbone="hrnet_w48", pretrained=True)
    else:
        return None
    return net.to(device)


def _keep_best(best, name, value):
    """Return a new tracker for `name` if `value` beats `best['value']`, else `best` unchanged."""
    return {'model': name, 'value': value} if value > best['value'] else best


def _update_per_class(trackers, name, values):
    """In place, move each per-class tracker to `name` when the matching value beats it."""
    for tracker, value in zip(trackers, values):
        if value > tracker['value']:
            tracker['model'] = name
            tracker['value'] = value


yaml_filename = data.yaml_filename(num_subtiles, tiles, classes_mode)
all_reports = {}
for model_name in models_paths:
    print('Model name:', model_name)
    # Patch size and the crf/dist options come from *this* checkpoint's name,
    # not from the model selected at the top of the notebook.
    model_patch_size, model_crf, model_dist = _parse_model_name(model_name)
    model = _build_model(model_name, num_classes, model_crf, model_dist)
    if model is None:
        print('Nao existe esse modelo')
        continue
    # Every built model is a UNet* or HRNet*; HRNets get a smaller batch.
    BS = 16 if model_name.startswith('UNet') else 4

    print(yaml_filename)
    print('----------------')
    test_dataset = data.SubtileDataset(yaml_filename,
                                       set='test_files',
                                       patch_size=model_patch_size,
                                       stride=model_patch_size,  # non-overlapping patches
                                       dynamic_sampling=False,
                                       data_augmentation=False)
    dataloader = DataLoader(test_dataset, batch_size=BS, shuffle=False)

    # train.test_model receives the checkpoint name and presumably loads the
    # weights itself - TODO confirm; nothing is loaded in this cell.
    cm, report = train.test_model(model,
                                  checkpoint_path=model_name,
                                  dataloader=dataloader,
                                  device=device,
                                  num_classes=num_classes,
                                  set_name='-test-' + '-'.join(tiles))
    all_reports[model_name] = report

    highest_recall_global = _keep_best(highest_recall_global, model_name, report['recall']['weighted avg'])
    highest_recall_macro = _keep_best(highest_recall_macro, model_name, report['recall']['macro avg'])
    highest_precision_global = _keep_best(highest_precision_global, model_name, report['precision']['weighted avg'])
    highest_precision_macro = _keep_best(highest_precision_macro, model_name, report['precision']['macro avg'])
    highest_f1_global = _keep_best(highest_f1_global, model_name, report['f1-score']['weighted avg'])
    highest_f1_macro = _keep_best(highest_f1_macro, model_name, report['f1-score']['macro avg'])

    _update_per_class(highest_f1, model_name, [report['f1-score'][row] for row in _class_rows])
    _update_per_class(highest_precision, model_name, [report['precision'][row] for row in _class_rows])
    _update_per_class(highest_recall, model_name, [report['recall'][row] for row in _class_rows])

    for ordered, metric_key in ((ordered_f1s, 'f1-score'),
                                (ordered_recalls, 'recall'),
                                (ordered_precisions, 'precision')):
        for row, scores in ordered.items():
            scores.append({'model': model_name, 'value': report[metric_key][row]})

# Keep only the five best models for each metric and report row.
for ordered in (ordered_f1s, ordered_recalls, ordered_precisions):
    for row in ordered:
        ordered[row] = sorted(ordered[row], key=lambda x: x['value'], reverse=True)[:5]


In [None]:

# Best model for each regrouped recall summary: binary, 4 classes, 5 classes.
_recall_sections = [
    ('Binary', highest_binary_macro, highest_binary_weighted, highest_binary_global, binary_labels, highest_binary),
    ('4 class', highest_4class_macro, highest_4class_weighted, highest_4class_global, join_labels, highest_4class),
    ('5 class', highest_5class_macro, highest_5class_weighted, highest_5class_global, class_labels, highest_5class),
]
for section_idx, (section, macro_best, weighted_best, global_best, section_labels, per_class) in enumerate(_recall_sections):
    if section_idx:
        print()
    print(f'{section} macro:', macro_best)
    print(f'{section} weighted:', weighted_best)
    print(f'{section} global:', global_best)
    for idx, label in enumerate(section_labels):
        print(f'{section} Recall, class {label}', per_class[idx])



In [None]:
# Best model for F1 and precision: macro avg, weighted avg ("global"), per class.
for metric_title, macro_best, global_best, per_class, row_prefix in (
        ('F1-score: ', highest_f1_macro, highest_f1_global, highest_f1, 'f1'),
        ('precision: ', highest_precision_macro, highest_precision_global, highest_precision, 'precision')):
    print(metric_title)
    print('macro:', macro_best)
    print('global:', global_best)
    for idx, label in enumerate(class_labels):
        print(f'{row_prefix}, class {label}', per_class[idx])


In [None]:
# Classification report of every evaluated model, keyed by checkpoint file name.
all_reports

In [None]:

names = list(all_reports.keys())  # X-axis labels

# Per-class test recall of each model, in the same order as `names`.
# One entry per class (num_classes) rather than a hardcoded 5.
recalls = []
for k, v in all_reports.items():
    print(k)
    recalls.append([v['recall'][f'Class {i}'] for i in range(num_classes)])

if all_reports:
    # Recall row of the last model, as a quick sanity check.
    print(v['recall'])
# Checkpoint-name suffixes handed to view.plot_metric (presumably used to
# group the models by loss / sampling variant).
suffixes = ['-type-CE.pth', '-type-CEW.pth', '-type-DS-CE.pth', '-type-DS-CEW.pth']
save_to = os.path.join(working_dir, "figs", "test_recall.png")
view.plot_metric(recalls, names, suffixes, metric_name="Recall", save_to=save_to)


### Divisão por conjunto de classes (binária, 4 classes e 5 classes)

In [None]:
# Legacy evaluation: for the selected checkpoints, track the model with the
# best recall per regrouped label set (binary, 4 classes, 5 classes) on the
# union of train + val + test sub-tiles.
# NOTE(review): this cell looks out of date with respect to the test loop above:
#   - above, `train.test_model` is unpacked as `cm, report`; here its result is
#     indexed like a dict (`recalls['binary'][...]`) - confirm the current
#     return type before running;
#   - above, `data.SubtileDataset` takes a yaml file and `set=`; here it takes
#     a file list plus `num_subtiles`/`classes_mode`;
#   - the split below omits `subfolder='q_12ch'`, unlike the earlier split;
#   - `torch.load(ckp_file, ...)` loads the checkpoint selected at the top of
#     the notebook (not `model_name`) and its result is never used;
#   - `crf`, `dist` and `patch_size` also come from that selected model.
train_files, val_files, test_files = data.train_val_test_stratify(tiles, 
                                                                  num_subtiles,
                                                                    train_size = 0.6, 
                                                                    val_size = 0.2, 
                                                                    stratify_by = 'type')


# Best-model trackers (macro / weighted / global recall) per label grouping.
highest_binary_macro = {'model':'', 'value': 0.0}
highest_4class_macro = {'model':'', 'value': 0.0}
highest_5class_macro = {'model':'', 'value': 0.0}

highest_binary_weighted = {'model':'', 'value': 0.0}
highest_4class_weighted = {'model':'', 'value': 0.0}
highest_5class_weighted = {'model':'', 'value': 0.0}

highest_binary_global = {'model':'', 'value': 0.0}
highest_4class_global = {'model':'', 'value': 0.0}
highest_5class_global = {'model':'', 'value': 0.0}


# Names of the regrouped classes: binary merges {0,3} vs {1,2,4}; the 4-class
# grouping merges only 0 and 3.
binary_labels = ["(0,3)", "(1,2,4)"]
join_labels = ["(0,3)", "(1)", "(2)", "(4)"]
class_labels = list(range(num_classes))
class_labels = [str(cl) for cl in class_labels]

# Per-class best-recall trackers for each grouping.
highest_binary = [{'model':'', 'class':'(0,3)','value': 0.0}, {'model':'', 'class':'(1,2,4)', 'value': 0.0}]
highest_4class = [{'model':'', 'class':'(0,3)', 'value': 0.0}, {'model':'', 'class':'1', 'value': 0.0}, {'model':'', 'class':'2', 'value': 0.0}, {'model':'', 'class':'4', 'value': 0.0}]
highest_5class = [{'model':'', 'class':'0', 'value': 0.0}, {'model':'', 'class':'1', 'value': 0.0}, {'model':'', 'class':'2', 'value': 0.0}, {'model':'', 'class':'3', 'value': 0.0}, {'model':'', 'class':'4', 'value': 0.0}]

# Only checkpoints whose name contains 'UNet-256-type-DS-CEW' are evaluated.
for model_name in [mp for mp in models_paths if 'UNet-256-type-DS-CEW' in mp]:#model_paths:
#for model_name in models_paths:
    print('Model name:', model_name)
    # Instantiate the architecture named by the file-name prefix; skip unknown ones.
    if model_name.startswith('UNetSmall-'):
        model = unets.UNetSmall(in_channels=12, out_channels=num_classes, crf=crf, use_dist=dist).to(device) 
    elif model_name.startswith('UNet-'):
        model = unets.UNet(in_channels=12, out_channels=num_classes).to(device) 
    elif model_name.startswith('UNetResNet34-'):
        model = unets.UNetResNet34(in_channels=12, out_channels=num_classes).to(device) 
    elif model_name.startswith('UNetEfficientNetB0-'):
        model = unets.UNetEfficientNetB0(in_channels=12, out_channels=num_classes).to(device) 
    elif model_name.startswith('UNetConvNext-'):
        model = unets.UNetConvNext(in_channels=12, out_channels=num_classes).to(device) 
    elif model_name.startswith('HRNetW18'):
        model = hrnets.HRNetSegmentation(in_channels= 12, num_classes=num_classes, backbone="hrnet_w18_small", pretrained=True,).to(device)
    elif model_name.startswith('HRNetW32'):
        model = hrnets.HRNetSegmentation(in_channels= 12, num_classes=num_classes, backbone="hrnet_w32", pretrained=True,).to(device)
    elif model_name.startswith('HRNetW48'):
        model = hrnets.HRNetSegmentation(in_channels= 12, num_classes=num_classes, backbone="hrnet_w48", pretrained=True,).to(device)
    else:
        print('Nao existe esse modelo')
        continue
    # NOTE(review): loads `ckp_file` (the model selected at the top), not
    # `model_name`; the result is not used below.
    checkpoint = torch.load(ckp_file, weights_only=False)

    # Batch size per family (HRNets get a smaller batch).
    if model_name.startswith('UNet'):
        BS = 16
    elif model_name.startswith('HRNet'):
        BS = 4
    else:
        print('não existe esse modelo')
        continue

    full_dataset = data.SubtileDataset(train_files+val_files+test_files, 
                                            num_subtiles = num_subtiles, 
                                            classes_mode = 'type', 
                                            patch_size = patch_size, 
                                            stride=patch_size, # no overlap between patches
                                            dynamic_sampling = False,
                                            data_augmentation = False
                                           )

    dataloader = DataLoader(full_dataset, batch_size=BS, shuffle=False)

    recalls = train.test_model(model, 
                    checkpoint_path=model_name,
                    dataloader = dataloader, 
                    device = device, 
                    num_classes = num_classes,
                    set_name = '-full-032027'
                    ) 
    print(recalls)
    # Keep the model with the highest macro / weighted / global recall per grouping.
    if recalls['binary']['macro_recall']>highest_binary_macro['value']:
        highest_binary_macro = {'model':model_name, 'value':recalls['binary']['macro_recall']}
    if recalls['binary']['weighted_recall']>highest_binary_weighted['value']:
        highest_binary_weighted = {'model':model_name, 'value':recalls['binary']['weighted_recall']}
    if recalls['binary']['global_recall']>highest_binary_global['value']:
        highest_binary_global = {'model':model_name, 'value':recalls['binary']['global_recall']}

    if recalls['4class']['macro_recall']>highest_4class_macro['value']:
        highest_4class_macro = {'model':model_name, 'value':recalls['4class']['macro_recall']}
    if recalls['4class']['weighted_recall']>highest_4class_weighted['value']:
        highest_4class_weighted = {'model':model_name, 'value':recalls['4class']['weighted_recall']}
    if recalls['4class']['global_recall']>highest_4class_global['value']:
        highest_4class_global = {'model':model_name, 'value':recalls['4class']['global_recall']}

    if recalls['5class']['macro_recall']>highest_5class_macro['value']:
        highest_5class_macro = {'model':model_name, 'value':recalls['5class']['macro_recall']}
    if recalls['5class']['weighted_recall']>highest_5class_weighted['value']:
        highest_5class_weighted = {'model':model_name, 'value':recalls['5class']['weighted_recall']}
    if recalls['5class']['global_recall']>highest_5class_global['value']:
        highest_5class_global = {'model':model_name, 'value':recalls['5class']['global_recall']}


    # Per-class best recall for each grouping (trackers updated in place).
    for i, bl in enumerate(binary_labels):
        if recalls['binary'][bl+'_recall']>highest_binary[i]['value']:
            highest_binary[i]['model'] = model_name
            highest_binary[i]['value'] = recalls['binary'][bl+'_recall']
    for i, jl in enumerate(join_labels):
        if recalls['4class'][jl+'_recall']>highest_4class[i]['value']:
            highest_4class[i]['model'] = model_name
            highest_4class[i]['value'] = recalls['4class'][jl+'_recall']
    for i, cl in enumerate(class_labels):
        if recalls['5class'][cl+'_recall']>highest_5class[i]['value']:
            highest_5class[i]['model'] = model_name
            highest_5class[i]['value'] = recalls['5class'][cl+'_recall']




In [None]:

# Best model for each regrouped recall summary: binary, 4 classes, 5 classes.
_recall_sections = [
    ('Binary', highest_binary_macro, highest_binary_weighted, highest_binary_global, binary_labels, highest_binary),
    ('4 class', highest_4class_macro, highest_4class_weighted, highest_4class_global, join_labels, highest_4class),
    ('5 class', highest_5class_macro, highest_5class_weighted, highest_5class_global, class_labels, highest_5class),
]
for section_idx, (section, macro_best, weighted_best, global_best, section_labels, per_class) in enumerate(_recall_sections):
    if section_idx:
        print()
    print(f'{section} macro:', macro_best)
    print(f'{section} weighted:', weighted_best)
    print(f'{section} global:', global_best)
    for idx, label in enumerate(section_labels):
        print(f'{section} Recall, class {label}', per_class[idx])

### Experimentos com a limpeza morfológica da inferência no conjunto de teste

In [None]:
tile = '032027'

# Full paths of every sub-tile GeoTIFF of this tile.
folder = os.path.join(working_dir, f"data/processed/S2-16D_V2_{tile}/{num_subtiles}x{num_subtiles}_subtiles")
files = [os.path.join(folder, name) for name in os.listdir(folder) if name.endswith('.tif')]





In [None]:
# Not populated by any visible cell - presumably a scratch list for picking models.
model_selection = list()

In [None]:
import time

# Sliding-window settings for tile reconstruction: consecutive patches overlap
# by 32 px (16 px for 64-px patches). `edge_removal` is passed to
# post.ReconstructTile (presumably border pixels discarded per patch when stitching).
stride = patch_size-32
edge_removal = 8
if patch_size == 64:
    stride = patch_size-16
    edge_removal = 4

# NOTE(review): `model` is the network left by the last evaluation loop, while
# `patch_size` and `criterion` were derived from the model selected at the top
# of the notebook - make sure both refer to the same checkpoint.
for tile_id in ['032027']:#, '025037', '032027']:
    # --------------- opening files -----------------
    folder = os.path.join(working_dir, f"data/processed/S2-16D_V2_{tile_id}/{num_subtiles}x{num_subtiles}_subtiles")
    files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.tif')]
    # --------------- creating a dataloader -----------------
    test_dataset = data.SubtileDataset(files,
                                       num_subtiles=num_subtiles,
                                       classes_mode=classes_mode,
                                       patch_size=patch_size,
                                       stride=stride,
                                       dynamic_sampling=False,
                                       data_augmentation=False,
                                       return_imgidx=True)
    test_loader = DataLoader(test_dataset, batch_size=training_batch_size, shuffle=False)
    # Rebinds `tile` (until now the tile id string) to the stitching object
    # that the following cells read (labels, preds, post_process, ...).
    tile = post.ReconstructTile(patch_size=patch_size, stride=stride, edge_removal=edge_removal)

    ### -------------- TESTING -------------------
    # Peak-memory statistics need a CUDA device; skip them on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    run_time = time.perf_counter()
    runner = train.EpochRunner('test', model, test_loader, criterion=criterion, num_classes=num_classes,
                               optimizer=None, simulated_batch_size=test_loader.batch_size, device=device)
    for image, label, logits, pred, x, y, f in runner.run_generator(show_pred=1):
        tile.add_batch(x, y, f, logits, pred, label, image)
    print('Montando')
    tile.set_pred()

    loss, CE, dice, report, acc, cm = runner.get_metrics()
    run_time = time.perf_counter() - run_time
    peak_train_memory = f"{torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" if use_cuda else "n/a"
    torch.cuda.empty_cache()

    print(f'Test Loss: {loss}, {CE}, {dice}')
    print(f'Test Accuracy: {acc}')
    print(f'Run time: {run_time:.1f} s | peak GPU memory: {peak_train_memory}')
    print('Test confusion matrix:')
    view.plot_confusion_matrix(cm)
    print(report)

    # Window of the reconstructed tile to display: [row_start, row_end, col_start, col_end].
    r = [0, 10560, 0, 10560]
    plt.figure(figsize=(50, 50))
    plt.subplot(1, 2, 1)
    plt.imshow(tile.labels[r[0]:r[1], r[2]:r[3]])
    plt.subplot(1, 2, 2)
    plt.imshow(tile.preds[r[0]:r[1], r[2]:r[3]])
    plt.show()
    # Only the first tile is processed; the cells below expect `tile` to hold it.
    break

In [None]:
# Post-process the reconstructed tile (arguments (0, 0) forwarded to
# ReconstructTile.post_process) and show every intermediate mask.
post_outputs = tile.post_process(0, 0)
(labels, pred_patch, clean_pred, clean_noholes, clean_noholes_2,
 noholes, noholes2, rules) = post_outputs
# NOTE(review): `titles` follows the mask order below but is not passed to
# view.plot_mask_list - check whether that helper accepts titles.
titles = ["Referência", "Predição do modelo", "CRF", "CRF + buracos preenchidos", "CRF + buracos preenchidos 2", "Com buracos preenchidos", "Com buracos preenchidos 2", "Regras"]
masks = [labels, pred_patch, clean_pred, clean_noholes, clean_noholes_2, noholes, noholes2, rules]
view.plot_mask_list(masks)

In [None]:
# Colormap, opacity and display name for each class id.
label_map = {
        0: ('Reds', 0.0, 'Fundo'),
        1: ('Blues', 0.8, 'Loteamento Vazio'),
        2: ('Greens', 0.8, 'Outros Equipamentos'),
        3: ('Reds', 0.8, 'Vazio Intraurbano'),
        4: ('Oranges', 0.8, 'Área Urbanizada')
    }

# One panel per mask, titled with its variable name.
# - `pred_patch` is the model prediction returned by post_process. The bare
#   `pred` at this point is only the last batch yielded by the test runner.
# - The sixth panel shows `crfc2` (previously `crfc1` was drawn twice).
# NOTE(review): `vazios`, `crfc1`, `crfc2`, `c1` and `c2` are not defined in
# any visible cell; this cell raises NameError until they are provided.
panels = [
    (labels, 'labels'),
    (pred_patch, 'pred_patch'),
    (clean_pred, 'clean_pred'),
    (vazios, 'vazios'),
    (crfc1, 'crfc1'),
    (crfc2, 'crfc2'),
    (c1, 'c1'),
    (c2, 'c2'),
]
fig, axs = plt.subplots(4, 2, figsize=(80, 20))
for ax, (mask, title) in zip(axs.flatten(), panels):
    view.plot_masked_image(mask, label_map, image=None, title=title, ax=ax)






In [None]:
# NOTE(review): `slg` is not defined in any visible cell; this raises
# NameError on a fresh kernel.
print(slg)

In [None]:
from skimage.morphology import remove_small_objects

def remove_small(final_mask, min_size = 10):
    """Remove objects smaller than `min_size` pixels using skimage's remove_small_objects.

    NOTE(review): if `final_mask` is an integer label map, remove_small_objects
    treats each label value as one already-labelled object. Small disconnected
    blobs of a class are then not removed individually. Confirm this is
    intended for `tile.crf`.
    """
    return remove_small_objects(final_mask, min_size=min_size)


a = remove_small(tile.crf, min_size=25)
plt.imshow(a)

In [None]:

import matplotlib.pyplot as plt
# Min-max normalise the raw tile to [0, 1] (`t` is only needed for an optional
# grayscale display). Dividing out of place works for integer rasters too.
# A constant tile yields zeros instead of a division by zero.
t = tile.tile - np.min(tile.tile)
t_max = np.max(t)
t = t / t_max if t_max > 0 else np.zeros_like(t, dtype=float)

# Window [row_start, row_end, col_start, col_end] shown side by side.
r = [3000, 7000, 0, 4000]
#plt.imshow(t, cmap='gray')
fig, axs = plt.subplots(1, 2, figsize=(50, 50))
axs[0].imshow(tile.labels[r[0]:r[1], r[2]:r[3]])
axs[0].set_title('labels')
axs[1].imshow(tile.preds[r[0]:r[1], r[2]:r[3]])
axs[1].set_title('preds')
plt.show()

In [None]:
# Compare labels and predictions over rows 3000:7000, columns 0:4000.
rows, cols = slice(3000, 7000), slice(0, 4000)
view.plt_tile(tile.labels[rows, cols], tile.preds[rows, cols])


In [None]:
# Colormap, opacity and display name for each class id.
label_map = {
        0: ('Reds', 0.0, 'Fundo'),
        1: ('Blues', 0.8, 'Loteamento Vazio'),
        2: ('Greens', 0.8, 'Outros Equipamentos'),
        3: ('Reds', 0.8, 'Vazio Intraurbano'),
        4: ('Oranges', 0.8, 'Área Urbanizada')
    }
# Displayed window: rows 3000:7000, columns 0:4000.
window = (slice(3000, 7000), slice(0, 4000))

fig, axs = plt.subplots(1, 2, figsize=(16, 8))
view.plot_masked_image(tile.labels[window], label_map, image=None, title="Mask Overlay Only", ax=axs[0])
view.plot_masked_image(tile.preds[window], label_map, image=None, title="Mask Overlay Only", ax=axs[1])
plt.tight_layout()
plt.show()

# plt.subplots(2, 2) returns a 2-D array of Axes. Flatten it so each call gets a
# single Axes; without flattening, axs[0] would be a whole row.
fig, axs = plt.subplots(2, 2, figsize=(16, 16))
axs = axs.flatten()
panels = [
    (tile.cleaned_preds, "Cleaned"),
    (tile.crf, "CRF"),
    (tile.cleaned_crf, "CRF+Clean"),
    (tile.vi_completion, "+completion"),
]
for ax, (mask, title) in zip(axs, panels):
    view.plot_masked_image(mask[window], label_map, image=None, title=title, ax=ax)
plt.tight_layout()
plt.show()


In [None]:
# NOTE(review): `pca_img` is not defined in any visible cell (leftover from an
# earlier session?); this raises NameError on a fresh kernel.
np.max(pca_img)


In [None]:
# NOTE(review): `tile_032026` is not created by any visible cell; this raises
# NameError on a fresh kernel.
#plt.figure(figsize=(10,10))
#plt.imshow(tile_032026.logits)
tile_032026.preds[:200,:200]

In [None]:
# NOTE(review): `tile_032026` is not created by any visible cell.
plt.figure(figsize=(10,10))
plt.imshow(tile_032026.logits_0)

In [None]:
# 32x32 zoom: `x` selects the rows and `y` the columns of each tile layer.
x = (256, 256+32)
y = (2256, 2256+32)

# NOTE(review): `tile_032026` is not created by any visible cell.
for layer_name in ('logits', 'logits_0', 'labels'):
    plt.figure(figsize=(10,10))
    plt.imshow(getattr(tile_032026, layer_name)[x[0]:x[1], y[0]:y[1]])

In [None]:
import rasterio
import matplotlib.pyplot as plt
import numpy as np

# Result GeoTIFF written for tile 032027 by UNet-256-type-DS-CEW (presumably
# the stitched prediction). The path is built from `working_dir` (project root,
# as in the other data paths) instead of a machine-specific home directory.
path = os.path.join(working_dir, 'data', 'results', 'S2-16D_V2_032027', '6x6_subtiles',
                    'UNet-256-type-DS-CEW', 'S2-16D_V2_032027.tif')
# Use names that do not shadow the project's `data` module alias and `src`
# package; otherwise re-running earlier cells would break.
with rasterio.open(path) as raster_src:
    raster = raster_src.read()
    meta = raster_src.meta
plt.imshow(np.squeeze(raster))
plt.show()
plt.show()