# Montagem dos tiles e pós-processamento

In [None]:
import torch

def print_gpu_memory(prefix=""):
    """Print how much CUDA memory PyTorch has allocated and reserved, in MB.

    Each printed line starts with ``prefix``. Without CUDA, a single notice
    is printed instead.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return
    mb = 1024 ** 2
    readings = (
        ("Allocated", torch.cuda.memory_allocated()),
        ("Reserved", torch.cuda.memory_reserved()),
    )
    for label, n_bytes in readings:
        print(f"{prefix} Memory {label}: {n_bytes / mb:.2f} MB")


# Use the GPU when one is present, drop cached CUDA blocks, and report usage.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
torch.cuda.empty_cache()

print_gpu_memory()

In [None]:
# imports

import os
import sys
sys.path.append(os.path.abspath('..'))

import src.data.preprocess_data as data
import src.training.train_model as train
import src.data.view as view
import src.models.unets as unets
import src.models.hrnets as hrnets
import src.training.post_processing as post

from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [None]:


# Band names to feed the model, keyed by the number of input channels.
# The 12-band list is the full set stored in the q_12ch subtile files.
channels_dict = {
    12: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A'],
    10: ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12', 'B8A'],
    8: ['B02', 'B03', 'B04', 'B05', 'B06', 'B08', 'B11', 'B12'],
    6: ['B02', 'B03', 'B04', 'B06', 'B08', 'B11'],
    4: ['B02', 'B03', 'B04', 'B08'],
}

In [None]:
num_subtiles = 6
working_dir = os.path.abspath('..')

# Number of output classes per labelling scheme; any other mode is binary.
classes_mode = '4types'
_classes_per_mode = {'type': 5, 'density': 4, '4types': 4, 'all': 9}
num_classes = _classes_per_mode.get(classes_mode, 2)


### Exemplo

Primeiramente, vamos montar um tile de exemplo a partir das predições de um modelo fixo.
A seguir, definimos o modelo e os parâmetros.

In [None]:
tile_id = '025037'
ch = 4
model_name = f'UNetSmall-256-4types-DS-CEW-{ch}ch-4tt'
patch_size = 256
BS = 16  # batch size

# --------------- loading the model -----------------
print('Model name:', model_name)
if model_name.startswith('UNetSmall-'):
    model = unets.UNetSmall(in_channels=ch, out_channels=num_classes).to(device)
else:
    # Only UNetSmall is wired up in this example cell; fail loudly here
    # instead of with a NameError on `model` further down.
    raise ValueError(f'Nao existe esse modelo: {model_name}')

ckp_file = os.path.join(working_dir, 'models', model_name + '.pth')
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
# weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
checkpoint = torch.load(ckp_file, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['best_model_state_dict'])

# --------------- reconstruction parameters -----------------
stride = patch_size - 32  # consecutive patches overlap by 32 pixels
edge_removal = 8

# --------------- opening files -----------------
# Inference input: the 12-channel subtiles of this tile.
folder = os.path.join(working_dir, f"data/processed/S2-16D_V2_{tile_id}/{num_subtiles}x{num_subtiles}_subtiles/q_12ch")
files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.tif')]

# --------------- creating a dataloader -----------------
# NOTE(review): every band of channels_dict[ch] is also in channels_dict[12], so this
# comprehension always evaluates to list(range(ch)). For ch=4 that is [0, 1, 2, 3],
# which selects the first four bands of the 12-band stack, not B02/B03/B04/B08. If the
# q_12ch files follow channels_dict[12] order, the intended mapping is probably
# [channels_dict[12].index(b) for b in channels_dict[ch]]. Change it only together with
# the training code, because the checkpoints were trained with whatever mapping it used.
indices = [i for i, value in enumerate(channels_dict[ch]) if value in channels_dict[12]]

tile_dataset = data.SubtileDataset(files,
                                   num_subtiles=num_subtiles,
                                   classes_mode=classes_mode,
                                   patch_size=patch_size,
                                   stride=stride,
                                   dynamic_sampling=False,
                                   data_augmentation=False,
                                   channels_subset=indices,
                                   return_imgidx=True)
dataloader = DataLoader(tile_dataset, batch_size=BS, shuffle=False)



# Montando o tile

Primeiro, criamos o objeto que guarda as reconstruções (`post.ReconstructTile`).

Em seguida, fazemos a inferência com o modelo (em `runner.run_generator`) e adicionamos cada batch ao reconstrutor (`tile.add_batch`).

Depois, montamos a predição do tile inteiro (em `tile.set_pred()`) e salvamos os resultados (`tile.save_pred`).

O pós-processamento do tile inteiro (`tile.post_process_tile`) é feito mais adiante, em uma célula própria.

In [None]:

import time

# Reconstruction object: receives per-patch outputs and stitches them into the tile.
tile = post.ReconstructTile(patch_size=patch_size, stride=stride, edge_removal=edge_removal,
                            num_classes=num_classes, num_channels=ch, tile_id=tile_id)

# reset_peak_memory_stats initializes CUDA and raises on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
run_time = time.perf_counter()

print('Inferindo...')
runner = train.EpochRunner('test', model, dataloader, num_classes=num_classes,
                           simulated_batch_size=dataloader.batch_size, device=device)
for image, label, logits, pred, x, y, f in runner.run_generator(show_pred=0):
    tile.add_batch(x, y, f, logits, pred, label, image)
print('Montando...')
tile.set_pred()

print('Salvando...')
# NOTE(review): the batch loop further down calls tile.save_pred(tile_id, folder_name=...);
# confirm which signature ReconstructTile.save_pred expects.
tile.save_pred(folder_name=model_name)

loss, CE, dice, report, acc, cm = runner.get_metrics()
run_time = time.perf_counter() - run_time
peak_train_memory = (f"{torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
                     if torch.cuda.is_available() else "n/a")
torch.cuda.empty_cache()

print(f'Test Loss: {loss}, {CE}, {dice}')
print(f'Test Accuracy: {acc}')
print(f'Tempo de execucao: {run_time:.1f} s | Pico de memoria GPU: {peak_train_memory}')
view.plot_single_confusion_matrix(cm)
print(report)

# ----------------- reconstructed tile: reference vs prediction -----------------
# Region to display: [axis-0 start, axis-0 end, axis-1 start, axis-1 end].
r = [0, 10560, 0, 10560]
# r = [5000, 7000, 1000, 3000]  # smaller window for a closer look

plt.figure(figsize=(50, 50))
plt.subplot(1, 2, 1)
plt.imshow(tile.labels[r[0]:r[1], r[2]:r[3]])
plt.subplot(1, 2, 2)
plt.imshow(tile.preds[r[0]:r[1], r[2]:r[3]])
plt.show()



## Visualizando a limpeza das predições

Podemos aplicar a limpeza a qualquer trecho (patch) do tile.

Aqui, aplicamos num pedaço de 1760×1760 pixels do tile, com o canto em (`x_pos`, `y_pos`) = (0, 1000).

Altere esses valores para ver outras posições do tile.

In [None]:
# Clean one square window of the reconstructed tile; plot the reference, the raw
# prediction and two of the cleaned versions (crf_morph and sieve are not plotted).
x_pos, y_pos = 0, 1000  # window corner: index along axes 0 and 1 of the tile arrays
size = 1760             # window side length, in pixels
labels, pred_patch, crf, morph, crf_morph, sieve = tile.post_process_patch(x_pos, y_pos, size)
titles = ["Referência", "Predição do modelo", "CRF", "Limpeza morfológica"]
rgb = tile.rgb_image[x_pos:x_pos + size, y_pos:y_pos + size, :]

fig_dir = os.path.join(working_dir, 'figs')
os.makedirs(fig_dir, exist_ok=True)  # make sure the output folder exists before saving
save_to = os.path.join(fig_dir, 'post_clean_1.png')
# Use the configured class count rather than a hard-coded 4.
view.plot_mask_list([labels, pred_patch, crf, morph], background=rgb, titles=titles,
                    num_classes=num_classes, save_to=save_to)




# Fazendo a limpeza do tile inteiro

In [None]:

# Post-process (clean) the whole reconstructed tile.
# NOTE(review): the batch loop below calls tile.post_process_tile() with no arguments
# followed by tile.save_cleaning(folder_name=...); confirm that post_process_tile
# accepts folder_name.
tile.post_process_tile(folder_name = model_name)

### Quanta memória usou?

Aqui estimamos a memória ocupada pelo objeto reconstrutor: soma-se o tamanho (aproximado) em memória de cada um dos seus atributos.

In [None]:

def calcular_tamanho_atributos(objeto):
    """Print the approximate memory footprint of ``objeto`` and of each attribute.

    Every name returned by ``dir(objeto)`` that does not start with ``__`` is
    inspected; methods and other routines are skipped because they are not data
    held by the object. Torch tensors are measured as
    ``element_size() * nelement()`` and NumPy arrays by ``nbytes``, because
    ``sys.getsizeof`` ignores a tensor's storage and a NumPy view's buffer.
    Everything else uses the shallow ``sys.getsizeof``.

    Returns:
        int: approximate total size in bytes (object itself plus attributes).
    """
    import inspect

    def _tamanho_em_bytes(valor):
        # torch.Tensor: count the elements, not just the Python wrapper.
        if hasattr(valor, 'element_size') and hasattr(valor, 'nelement'):
            return valor.element_size() * valor.nelement()
        # NumPy arrays, whether they own their buffer or are views.
        if hasattr(valor, 'nbytes') and hasattr(valor, 'dtype'):
            return int(valor.nbytes)
        return sys.getsizeof(valor)

    tamanho_total = sys.getsizeof(objeto)
    print(f"Tamanho do objeto: {tamanho_total} bytes")

    for atributo in dir(objeto):
        if atributo.startswith("__"):  # skip special attributes
            continue
        try:
            valor_atributo = getattr(objeto, atributo)
        except AttributeError:
            continue  # listed by dir() but not readable
        if inspect.isroutine(valor_atributo):
            continue  # bound methods are not data stored on the object
        tamanho_atributo = _tamanho_em_bytes(valor_atributo)
        print(f"  Atributo '{atributo}': {tamanho_atributo/1_048_576:.3f} Megabytes")
        tamanho_total += tamanho_atributo

    print(f"Tamanho total (aproximado): {tamanho_total/1_048_576:.3f} Megabytes")
    return tamanho_total

# The function prints its own per-attribute report; wrapping it in print() only adds noise.
calcular_tamanho_atributos(tile)


# Montando todos os tiles utilizados

Para um modelo fixo (UNetSmall-256-4types-DS-CEW, 4 canais), usamos tanto o modelo do treinamento original quanto o do finetune.

Fazemos a montagem e o pós-processamento das predições de todos os 12 tiles.

In [None]:
# City -> tile id; the id selects the data/processed/S2-16D_V2_<tile_id> folder.
_tile_rows = [
    ('Manaus', '016009'),
    ('Porto Alegre', '025037'),
    ('Belo Horizonte', '032027'),
    ('Salvador', '038019'),

    ('Boa Vista', '015002'),
    ('Campo Grande', '021027'),
    ('Macapá', '025005'),
    ('Curitiba', '027032'),
    ('Brasília', '028022'),
    ('Rio de Janeiro', '033029'),
    ('Teresina', '034011'),
    ('Petrolina', '036016'),
]
all_tiles = dict(_tile_rows)

in_channels = [4]          # add 8 to also run the 8-channel models
finetune = [False, True]   # original checkpoint and its fine-tuned version

# Each entry: base model name, patch size, inference batch size, free-form note.
models = [
    # {'model_name': 'UNet-256-4types-DS-CEW', 'patch_size': 256, 'batch_size': 16, 'note': ''},
    {'model_name': 'UNetSmall-256-4types-DS-CEW', 'patch_size': 256, 'batch_size': 16, 'note': ''},
]


In [None]:
import gc
import time


def _build_model(name, in_ch, n_classes):
    """Instantiate the architecture named by the prefix of ``name``.

    Args:
        name: model/checkpoint name, e.g. 'UNetSmall-256-4types-DS-CEW-4ch-4tt'.
        in_ch: number of input channels.
        n_classes: number of output classes.

    Returns:
        A freshly constructed model (not yet moved to ``device``), or ``None``
        when no known prefix matches.
    """
    if name.startswith('UNetSmall-'):
        return unets.UNetSmall(in_channels=in_ch, out_channels=n_classes)
    if name.startswith('UNet-'):
        return unets.UNet(in_channels=in_ch, out_channels=n_classes)
    if name.startswith('UNetResNet34-'):
        return unets.UNetResNet34(in_channels=in_ch, out_channels=n_classes)
    if name.startswith('UNetEfficientNetB0-'):
        return unets.UNetEfficientNetB0(in_channels=in_ch, out_channels=n_classes)
    if name.startswith('UNetConvNext-'):
        return unets.UNetConvNext(in_channels=in_ch, out_channels=n_classes)
    hrnet_backbones = {'HRNetW18': 'hrnet_w18_small', 'HRNetW32': 'hrnet_w32', 'HRNetW48': 'hrnet_w48'}
    for prefix, backbone in hrnet_backbones.items():
        if name.startswith(prefix):
            return hrnets.HRNetSegmentation(in_channels=in_ch, num_classes=n_classes,
                                            backbone=backbone, pretrained=True)
    return None


for tile_id in all_tiles.values():
    # --------------- opening files -----------------
    # The 12-channel subtiles are shared by every model/channel/finetune combination,
    # so list them once per tile and skip tiles that have no input data.
    folder = os.path.join(working_dir, f"data/processed/S2-16D_V2_{tile_id}/{num_subtiles}x{num_subtiles}_subtiles/q_12ch")
    if not os.path.isdir(folder):
        print(f'Pasta nao encontrada, pulando tile {tile_id}: {folder}')
        continue
    files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.tif')]
    print(f'Tile {tile_id}: {len(files)} subtiles')
    if not files:
        continue

    for model_dict in models:
        # --------------- reconstruction parameters -----------------
        patch_size = model_dict['patch_size']
        stride = patch_size - 32
        edge_removal = 8
        if patch_size == 64:
            stride = patch_size - 16
            edge_removal = 4

        for ch in in_channels:
            # NOTE(review): every band of channels_dict[ch] is also in channels_dict[12],
            # so this always evaluates to list(range(ch)), i.e. the FIRST ch bands of the
            # 12-band stack. If the q_12ch files follow channels_dict[12] order, the intended
            # mapping is probably [channels_dict[12].index(b) for b in channels_dict[ch]].
            # Change it only together with the training code that produced the checkpoints.
            indices = [i for i, value in enumerate(channels_dict[ch]) if value in channels_dict[12]]

            for ft in finetune:
                model_name = f"{model_dict['model_name']}-{ch}ch-4tt"  # e.g. UNet-256-4types-DS-CE-6ch-4tt
                ckp_file = os.path.join(working_dir, 'models', model_name + '.pth')
                if ft:
                    model_name += '-finetuned-8ft'
                    ckp_file = os.path.join(working_dir, 'models', 'finetuned', model_name + '.pth')

                print('Model name:', model_name)
                model = _build_model(model_name, ch, num_classes)
                if model is None:
                    print('Nao existe esse modelo')
                    continue
                model = model.to(device)
                # map_location restores GPU-saved checkpoints on CPU-only hosts too.
                # weights_only=False unpickles arbitrary objects: only load trusted files.
                checkpoint = torch.load(ckp_file, map_location=device, weights_only=False)
                model.load_state_dict(checkpoint['best_model_state_dict'])
                del checkpoint  # only the weights are needed from here on

                # --------------- creating a dataloader -----------------
                tile_dataset = data.SubtileDataset(files,
                                                   num_subtiles=num_subtiles,
                                                   classes_mode=classes_mode,
                                                   patch_size=patch_size,
                                                   stride=stride,
                                                   dynamic_sampling=False,
                                                   data_augmentation=False,
                                                   channels_subset=indices,
                                                   return_imgidx=True)
                dataloader = DataLoader(tile_dataset, batch_size=model_dict['batch_size'], shuffle=False)

                tile = post.ReconstructTile(patch_size=patch_size, stride=stride, edge_removal=edge_removal,
                                            num_classes=num_classes, num_channels=ch, tile_id=tile_id)

                # reset_peak_memory_stats initializes CUDA and raises on CPU-only hosts.
                if torch.cuda.is_available():
                    torch.cuda.reset_peak_memory_stats()
                run_time = time.perf_counter()

                print('Inferindo...')
                runner = train.EpochRunner('test', model, dataloader, num_classes=num_classes,
                                           simulated_batch_size=dataloader.batch_size, device=device)
                for image, label, logits, pred, x, y, f in runner.run_generator(show_pred=0):
                    tile.add_batch(x, y, f, logits, pred, label, image)
                print('Montando...')
                tile.set_pred()

                print('Salvando...')
                tile.save_pred(tile_id, folder_name=model_name)

                loss, CE, dice, report, acc, cm = runner.get_metrics()
                run_time = time.perf_counter() - run_time
                peak_train_memory = (f"{torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
                                     if torch.cuda.is_available() else "n/a")
                torch.cuda.empty_cache()

                print(f'Test Loss: {loss}, {CE}, {dice}')
                print(f'Test Accuracy: {acc}')
                print(f'Tempo de execucao: {run_time:.1f} s | Pico de memoria GPU: {peak_train_memory}')
                view.plot_single_confusion_matrix(cm)
                print(report)

                # ----------------- reconstructed tile: reference vs prediction -----------------
                # Region to display: [axis-0 start, axis-0 end, axis-1 start, axis-1 end].
                r = [0, 10560, 0, 10560]
                plt.figure(figsize=(50, 50))
                plt.subplot(1, 2, 1)
                plt.imshow(tile.labels[r[0]:r[1], r[2]:r[3]])
                plt.subplot(1, 2, 2)
                plt.imshow(tile.preds[r[0]:r[1], r[2]:r[3]])
                plt.show()
                plt.close('all')  # don't keep 50x50-inch figures alive across iterations

                # --------------- post-processing -----------------
                tile.post_process_tile()
                tile.save_cleaning(folder_name=model_name)

                # Release the tile, model and loader before the next combination, then
                # hand freed GPU blocks back, so memory does not accumulate across runs.
                del tile, model, runner, dataloader, tile_dataset
                gc.collect()
                torch.cuda.empty_cache()
