# Finetune

In [None]:
import torch


In [None]:
def print_gpu_memory(prefix=""):
    """Print the CUDA memory currently allocated and reserved, in MB.

    Args:
        prefix: Text placed in front of each report line (a single space
            always separates it from the message).

    When CUDA is unavailable a single notice is printed instead.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    bytes_per_mb = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"{prefix} Memory Allocated: {allocated_mb:.2f} MB")
    print(f"{prefix} Memory Reserved: {reserved_mb:.2f} MB")


# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Release cached CUDA blocks left over from earlier cells before reporting usage.
torch.cuda.empty_cache()

print_gpu_memory()

In [None]:
# imports

import os
import sys
sys.path.append(os.path.abspath('..'))

import src.models.unets as unets
import src.data.preprocess_data as data
import src.training.train_model as train
import src.models.hrnets as hrnets

from torch.utils.data import DataLoader



### O finetune é feito em um conjunto de 8 tiles não usados no treino

In [None]:
# Tiles used for finetuning, keyed by city name; values are the tile ids.
tiles_finetune = {
    'Boa Vista': '015002',
    'Campo Grande': '021027',
    'Macapá': '025005',
    'Curitiba': '027032',
    'Brasília': '028022',
    'Rio de Janeiro': '033029',
    'Teresina': '034011',
    'Petrolina': '036016',
}

# Kept as a dict view (not a list) so the helpers that receive it see exactly
# the same object type as before.
tiles = tiles_finetune.values()
num_subtiles = 6
classes_mode = '4types'
model_types = 'unets'

# Only the 'unets' family has a batch size configured in this notebook.
if model_types == 'unets':
    training_batch_size = 16

# Number of output classes for each labelling scheme.
num_classes_by_mode = {
    'type': 5,
    '4types': 4,
    'density': 4,
    'binary': 2,
    'all': 9,
}

# Fail fast on an unknown scheme: silently skipping the assignment would leave
# `num_classes` undefined, or stale from a previous run of this cell.
if classes_mode not in num_classes_by_mode:
    raise ValueError(
        f"Unknown classes_mode {classes_mode!r}; "
        f"expected one of {sorted(num_classes_by_mode)}"
    )
num_classes = num_classes_by_mode[classes_mode]

# Band names that make up each supported channel count.
channels_dict = {
    12: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A'],
    10: ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12', 'B8A'],
    8: ['B02', 'B03', 'B04', 'B05', 'B06', 'B08', 'B11', 'B12'],
    6: ['B02', 'B03', 'B04', 'B06', 'B08', 'B11'],
    4: ['B02', 'B03', 'B04', 'B08'],
}

In [None]:
# Batch size associated with each architecture/patch-size combination; the
# keys carry the active `classes_mode` as a suffix.
unet_models_batch = {
    f'{architecture}-{classes_mode}': batch_size
    for architecture, batch_size in (
        ('UNetSmall-64', 16),  # 512 was an earlier setting
        ('UNetSmall-256', 32),
        ('UNet-64', 256),
        ('UNet-256', 16),
        ('UNetResNet34-224', 128),  # marked "ok" by the author
        ('UNetEfficientNetB0-224', 64),
        ('UNetConvNext-224', 32),
        ('HRNetW18-512', 4),
        ('HRNetW32-512', 4),
        ('HRNetW48-512', 4),
    )
}

In [None]:
if model_types == 'unets':
    # Architectures in the grid. Only UNetSmall-256 is enabled; the other
    # names previously toggled here were UNetSmall-64, UNet-256, UNet-64,
    # UNetResNet34-224, UNetEfficientNetB0-224 and UNetConvNext-224
    # (each followed by f'-{classes_mode}').
    architecture_grid = {
        'model': [f'UNetSmall-256-{classes_mode}'],
    }

    # Training options. Every value is a list of candidates for that option.
    training_grid = {
        # Loss: 'dice', 'CE-dice' and 'groups' were other values tried here.
        'loss': ['CE'],
        # Author's note: weighted loss + CE gives good recall for classes
        # 2, 3 and 4 and poor recall for the rest.
        'weighted_loss': [True],
        'dist_loss': [False],
        'crf': [False],  # 0.0001 was an alternative value
        'epochs': [15],
        'patience': [3],
        'batch_size': [training_batch_size],
        'dynamic_sampling': [True],
        'data_augmentation': [True],
        'num_channels': [8, 4],
    }

    model_param_grid = {**architecture_grid, **training_grid}



In [None]:

# Split the subtiles of the finetune tiles into train/validation/test file
# lists, stratified by the active class scheme.
# NOTE(review): 60% train + 20% validation presumably leaves 20% for test —
# confirm in data.train_val_test_stratify.
split_options = {
    'train_size': 0.6,
    'val_size': 0.2,
    'stratify_by': classes_mode,
    'subfolder': 'q_12ch',
}
train_files, val_files, test_files = data.train_val_test_stratify(
    tiles,
    num_subtiles,
    **split_options,
)



### Seleção de modelos a serem finetunados

Busca todos os checkpoints já treinados que sejam UNetSmall, do tipo DS-CEW (amostragem dinâmica com cross-entropy ponderada) e de 4 ou 8 canais.

In [None]:
working_dir = os.path.abspath('..')
models_dir = os.path.join(working_dir, 'models')

# Keep only the checkpoints eligible for finetuning: file name ends in
# 'tt.pth', architecture is UNetSmall, the name contains 'DS-CEW', and the
# model uses 4 or 8 channels.
# (A looser "'CE' in name" filter and a filter pinning the single file
# 'UNetSmall-256-4types-DS-CEW-4ch-4tt.pth' were alternatives used before.)
checkpoints = [
    filename
    for filename in os.listdir(models_dir)
    if filename.endswith('tt.pth')
    and filename.startswith('UNetSmall')
    and 'DS-CEW' in filename
    and ('-4ch-' in filename or '-8ch-' in filename)
]

print('Checkpoints salvos:')
print(checkpoints)

finetune_epochs = 15

# Turn the bare file names into full paths.
checkpoints = [os.path.join(models_dir, filename) for filename in checkpoints]
checkpoints


### Loop de finetune:

Para cada modelo: carrega o checkpoint final, aplica o modelo no conjunto de treino e calcula as métricas; em seguida cria um novo dataset com os patches cujo macro IoU é de no máximo 0,8 e que têm pelo menos 1% de pixels das classes 1 e 2 (loteamento e equipamentos), e inicia o re-treinamento nesse dataset.

In [None]:
import csv

import pandas as pd

# Finetuning loop.
#
# For every selected checkpoint:
#   1. load it and recover the training options encoded in the model name;
#   2. skip it when a finetune run with enough logged epochs already exists;
#   3. build the train/val/test datasets and the matching network;
#   4. run the loaded model over the training set and keep the patches whose
#      macro IoU is at most `threshold` and that contain at least 1% of pixels
#      of classes 1 and 2;
#   5. retrain the model on those patches only.
for checkpoint_path in checkpoints:
    if not os.path.exists(checkpoint_path):
        continue

    # The checkpoint is loaded once and reused for both the metadata and the
    # weights. `map_location` keeps the load working on the CPU fallback.
    # NOTE: weights_only=False unpickles arbitrary objects; only use this on
    # checkpoints produced locally.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    metadata = checkpoint['metadata']

    model_name = metadata['model_name']  # e.g. 'UNetSmall-256-4types-DS-CE-10ch-4tt'
    print("Modelo: ", model_name)
    print(metadata)

    # The path must not be stored in a variable called `csv`: that name is the
    # csv module used below to count the logged epochs.
    training_csv_path = metadata['file_path']
    df = pd.read_csv(training_csv_path)  # kept for interactive inspection

    # Recover the training options from the dash-separated model name,
    # e.g. ['UNetSmall', '256', '4types', 'DS', 'CE', '12ch', '4tt'].
    model_name_splitted = model_name.split('-')
    loss_token = model_name_splitted[-3]  # loss name, with a trailing 'W' when weighted
    model_params = {
        'crf': '-crf-' in model_name,
        'dist_loss': '-dist-' in model_name,
        'dynamic_sampling': '-DS-' in model_name,
        'data_augmentation': '-DA-' in model_name,
        'model_class': model_name_splitted[0],
        'patch_size': int(model_name_splitted[1]),
        'loss': loss_token.removesuffix('W'),
        'weighted_loss': loss_token.endswith('W'),
        'num_channels': int(model_name_splitted[-2].removesuffix('ch')),
    }
    print(model_params)

    print('--------------------')
    print('Finetuning', model_name)

    # NOTE(review): these paths use the 'finetune' folder while train_model
    # below is called with save_subfolder='finetuned' — confirm which folder
    # train_model really writes to, otherwise this skip check never fires.
    finetune_tag = f'{len(tiles)}ft'
    finetuned_filename = f'{model_name}-finetuned-{finetune_tag}.pth'
    finetuned_csv = os.path.join(working_dir, 'experimental_results', 'finetune',
                                 f'{model_name}-{finetune_tag}.csv')
    finetuned_model = os.path.join(working_dir, 'models', 'finetune', finetuned_filename)

    if os.path.exists(finetuned_csv) and os.path.exists(finetuned_model):
        with open(finetuned_csv, 'r', newline='', encoding='utf-8') as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile))
        print(f'Tuned for {line_count} epochs')
        if line_count > finetune_epochs:
            print('Finetune already completed.')
            continue

    model_class = model_params['model_class']
    patch_size = model_params['patch_size']

    # Channel indices handed to the datasets.
    # NOTE(review): the membership test is always true, so this always yields
    # range(num_ch), i.e. the FIRST num_ch bands of the 12-band layout rather
    # than the positions of channels_dict[num_ch] inside channels_dict[12].
    # Left unchanged because the checkpoints must be fed the same channels they
    # were trained with — verify against the training notebook before fixing.
    num_ch = model_params['num_channels']
    indices = [i for i, value in enumerate(channels_dict[num_ch]) if value in channels_dict[12]]

    yaml_filename = data.yaml_filename(num_subtiles, tiles, classes_mode)

    # Sampling and augmentation are disabled for all three datasets; the
    # training set additionally returns each patch's file and position so the
    # selected patches can be identified.
    train_dataset = data.SubtileDataset(yaml_filename,
                                        set='train_files',
                                        patch_size=patch_size,
                                        stride=patch_size,
                                        classes_mode=classes_mode,
                                        channels_subset=indices,
                                        dynamic_sampling=False,
                                        data_augmentation=False,
                                        return_imgidx=True,
                                        )

    val_dataset = data.SubtileDataset(yaml_filename,
                                      set='val_files',
                                      patch_size=patch_size,
                                      stride=patch_size,
                                      classes_mode=classes_mode,
                                      channels_subset=indices,
                                      dynamic_sampling=False,
                                      data_augmentation=False,
                                      )

    test_dataset = data.SubtileDataset(yaml_filename,
                                       set='test_files',
                                       patch_size=patch_size,
                                       stride=patch_size,
                                       classes_mode=classes_mode,
                                       channels_subset=indices,
                                       dynamic_sampling=False,
                                       data_augmentation=False,
                                       )

    # Inverse-frequency class weights over the full training set.
    # NOTE(review): not consumed below (the evaluation criterion is built with
    # weights=None); kept as in the original in case later cells rely on it.
    if model_params['weighted_loss']:
        class_counts, per_image = train_dataset.count_classes()
        class_weights = 1.0 / class_counts
        class_weights = class_weights / torch.sum(class_weights)
    else:
        class_weights = None

    dynamic_sampling = train_dataset.dynamic_sampling
    data_augmentation = train_dataset.data_augmentation

    train_loader = DataLoader(train_dataset, batch_size=training_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=training_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=training_batch_size, shuffle=False)

    # Instantiate the architecture named by the checkpoint.
    if model_name.startswith('UNetSmall-'):
        model = unets.UNetSmall(in_channels=num_ch, out_channels=num_classes,
                                crf=model_params['crf'],
                                use_dist=model_params['dist_loss']).to(device)
    elif model_name.startswith('UNet-'):
        model = unets.UNet(in_channels=num_ch, out_channels=num_classes,
                           crf=model_params['crf']).to(device)
    elif model_name.startswith('UNetResNet34-'):
        model = unets.UNetResNet34(in_channels=num_ch, out_channels=num_classes,
                                   crf=model_params['crf']).to(device)
    elif model_name.startswith('UNetEfficientNetB0-'):
        model = unets.UNetEfficientNetB0(in_channels=num_ch, out_channels=num_classes,
                                         crf=model_params['crf']).to(device)
    elif model_name.startswith('UNetConvNext-'):
        model = unets.UNetConvNext(in_channels=num_ch, out_channels=num_classes,
                                   crf=model_params['crf']).to(device)
    elif model_name.startswith('HRNetW18'):
        model = hrnets.HRNetSegmentation(in_channels=num_ch, num_classes=num_classes,
                                         backbone="hrnet_w18_small", pretrained=True).to(device)
    elif model_name.startswith('HRNetW32'):
        model = hrnets.HRNetSegmentation(in_channels=num_ch, num_classes=num_classes,
                                         backbone="hrnet_w32", pretrained=True).to(device)
    elif model_name.startswith('HRNetW48'):
        model = hrnets.HRNetSegmentation(in_channels=num_ch, num_classes=num_classes,
                                         backbone="hrnet_w48", pretrained=True).to(device)
    else:
        print(f'Modelo {model_name} não está no param grid. Pulando...')
        continue

    # Inference: restore the best weights and evaluate on the training set.
    model.load_state_dict(checkpoint['best_model_state_dict'])
    criterion = train.CombinedLoss(loss_mode=model_params['loss'], weights=None,
                                   return_all_losses=True)

    runner = train.EpochRunner('test', model, train_loader, criterion, num_classes=num_classes,
                               optimizer=None, simulated_batch_size=train_loader.batch_size,
                               device=device)

    patches = []
    ious = []
    threshold = 0.8
    counter = 0
    for image, label, logits, pred, x, y, f in runner.run_generator(show_pred=1):
        IOU, mean_IOU, macro_IOU = train.compute_iou_per_sample(pred, label, num_classes=num_classes)
        counter += pred.shape[0]
        for i in range(pred.shape[0]):
            # Class fractions are measured on this sample only; using the whole
            # batch would make every sample of a batch share the same values.
            # Assumes `label` is batched along dim 0 like `pred` — TODO confirm.
            sample_label = label[i]
            n_pixels = sample_label.numel()
            pct_loteamento = (sample_label == 1).sum().item() / n_pixels
            pct_equipamentos = (sample_label == 2).sum().item() / n_pixels
            pct_AU = (sample_label != 0).sum().item() / n_pixels

            if macro_IOU[i] <= threshold and pct_loteamento + pct_equipamentos >= 0.01:
                print(macro_IOU[i])
                print(IOU[i])
                print(pct_loteamento, pct_equipamentos)

                idx_dict = {'file': f[i],
                            'x': x[i].item(),
                            'y': y[i].item(),
                            'transform': 0,
                            'step_shift': 0,
                            }
                patches.append(idx_dict)
                # Per-sample values; `mean_IOU` is stored as returned because
                # its shape is not visible from this notebook.
                ious.append((IOU[i], mean_IOU, macro_IOU[i], pct_loteamento, pct_equipamentos))

    print('TOTAL SELECTED PATCHES:', len(patches))
    print('TOTAL PATCHES:', counter)
    if counter:
        print('PERCENTUAL SELECTED PATCHES:', 100*len(patches)/counter)
    if patches:
        print(patches[0])
        print(ious[0])

    # Metrics of the loaded model on the training set, before finetuning.
    loss, CE, dice, report, acc, cm = runner.get_metrics()
    peak_val_memory = f"{torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
    torch.cuda.empty_cache()
    print('_____________________________________')
    print(checkpoint_path)
    print(f'Loss: {loss}, {CE}, {dice}')
    print(f'Accuracy: {acc}')
    print('confusion matrix:')
    print(cm)

    # Nothing to retrain on: an empty selection would break the dataset and
    # the weight computation below.
    if not patches:
        print('No patches selected for finetuning. Skipping', model_name)
        continue

    # Dataset restricted to the selected patches, now with dynamic sampling
    # and data augmentation switched on.
    selection_train_dataset = data.SubtileDataset(patches,
                                                  set='train_files',
                                                  patch_size=patch_size,
                                                  stride=patch_size,
                                                  classes_mode=classes_mode,
                                                  channels_subset=indices,
                                                  dynamic_sampling=True,
                                                  data_augmentation=True,
                                                  return_imgidx=True,
                                                  )
    selection_train_loader = DataLoader(selection_train_dataset, batch_size=training_batch_size,
                                        shuffle=True)
    print('Tamanho do novo dataset para o finetune: ', len(selection_train_loader))

    selection_counts = selection_train_dataset.count_classes()
    print(selection_counts)
    pixel_count, subtile_count = selection_counts

    # Normalised inverse-frequency weights. A class with no pixels in the
    # selection gets weight 0 instead of inf, which would otherwise turn every
    # normalised weight into 0 or NaN.
    total_pixels = torch.sum(pixel_count)
    weights = [(total_pixels / c.clamp(min=1)) * (c > 0) for c in pixel_count]
    weights = [w/sum(weights) for w in weights]

    finetune_params = {'epochs': finetune_epochs,
                       'loss_mode': 'CE',
                       'patience': 3,
                       'weights': weights,  # a fixed [0.1, 1, 1, 0.5] was tried before
                       'save_to': finetuned_filename,
                       # NOTE(review): the original comment said this is 10x
                       # below a default of 0.1, yet the value passed is 0.1 —
                       # confirm the intended learning rate.
                       'maximum_lr': 0.1,
                       }
    print('Treinando...')
    print('pesos:', weights)
    train.train_model(model,
                      selection_train_loader,
                      val_loader,
                      epochs=finetune_params['epochs'],
                      loss_mode=finetune_params['loss_mode'],
                      device=device,
                      num_classes=num_classes,
                      simulated_batch_size=training_batch_size,
                      patience=finetune_params['patience'],
                      weights=finetune_params['weights'],
                      show_batches=1,
                      save_to=finetune_params['save_to'],
                      save_subfolder='finetuned',
                      maximum_lr=finetune_params['maximum_lr'])
        
