# Treinamento

Carrega os imports e define os tiles e os parâmetros de dados e de modelos.

In [None]:
import torch


In [None]:
def print_gpu_memory(prefix=""):
    """Print allocated and reserved CUDA memory of the current device.

    Values are reported in units of 1024**2 bytes (labelled "MB").

    Args:
        prefix: text printed, followed by a single space, at the start of each
            memory line (the space is printed even when ``prefix`` is empty).
            It is not used for the "CUDA is not available." message.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    bytes_per_mb = 1024 ** 2
    # Both counters are read before anything is printed.
    readings = (
        ("Allocated", torch.cuda.memory_allocated()),
        ("Reserved", torch.cuda.memory_reserved()),
    )
    for label, num_bytes in readings:
        print(f"{prefix} Memory {label}: {num_bytes / bytes_per_mb:.2f} MB")


# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Release unused cached GPU memory (e.g. from earlier runs in this kernel), then report usage.
torch.cuda.empty_cache()
print_gpu_memory()

In [None]:
# imports

import os
import sys
sys.path.append(os.path.abspath('..'))

import src.models.unets as unets
import src.data.preprocess_data as data
import src.training.train_model as train
import src.models.hrnets as hrnets

from torch.utils.data import DataLoader



In [None]:

# Tile ids per city; only the ids (dict values) are used downstream.
tiles_1 = {
    'Belo Horizonte': '032027',
}

tiles_4 = {
    'Manaus': '016009',
    'Porto Alegre': '025037',
    'Belo Horizonte': '032027',
    'Salvador': '038019',
}

tiles_8 = {
    'Boa Vista': '015002',
    'Campo Grande': '021027',
    'Macapá': '025005',
    'Curitiba': '027032',
    'Brasília': '028022',
    'Rio de Janeiro': '033029',
    'Teresina': '034011',
    'Petrolina': '036016',
}

# Named tile selections. dict.fromkeys de-duplicates while keeping a fixed order;
# set() would give a different order on each run because str hashing is randomized.
tile_sets = {
    '1 tile': list(tiles_1.values()),
    '4 tiles': list(tiles_4.values()),
    '8 tiles': list(tiles_8.values()),
    '12 tiles': list(dict.fromkeys([*tiles_4.values(), *tiles_8.values()])),
}

# ---- Experiment selection -------------------------------------------------
tiles = tile_sets['8 tiles']    # list of tile ids used for training
num_subtiles = 6
classes_mode = '4types'
model_types = 'unets'

# Default training batch size per model family (fail early instead of leaving
# training_batch_size undefined for an unsupported family).
_DEFAULT_BATCH_SIZE = {'unets': 16}
if model_types not in _DEFAULT_BATCH_SIZE:
    raise ValueError(f"Unsupported model_types {model_types!r}; "
                     f"expected one of {sorted(_DEFAULT_BATCH_SIZE)}")
training_batch_size = _DEFAULT_BATCH_SIZE[model_types]

# Number of output classes for each labelling mode.
_NUM_CLASSES = {
    'type': 5,
    '4types': 4,
    'density': 4,
    'binary': 2,
    'all': 9,
}
if classes_mode not in _NUM_CLASSES:
    raise ValueError(f"Unknown classes_mode {classes_mode!r}; "
                     f"expected one of {sorted(_NUM_CLASSES)}")
num_classes = _NUM_CLASSES[classes_mode]

# Band subsets keyed by channel count. The names look like Sentinel-2 bands;
# channels_dict[12] is presumably the band order of the stored 12-channel
# sub-tiles (subfolder 'q_12ch') — TODO confirm against the preprocessing code.
channels_dict = {
    12: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A'],
    10: ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12', 'B8A'],
    8: ['B02', 'B03', 'B04', 'B05', 'B06', 'B08', 'B11', 'B12'],
    6: ['B02', 'B03', 'B04', 'B06', 'B08', 'B11'],
    4: ['B02', 'B03', 'B04', 'B08'],
}

In [None]:
# Batch size per architecture/patch-size combination, keyed by the full model
# name ("<architecture>-<patch size>-<classes_mode>") used in model_param_grid.
unet_models_batch = {
    f'{arch_and_patch}-{classes_mode}': batch
    for arch_and_patch, batch in {
        'UNetSmall-64': 16,            # a value of 512 was noted here previously
        'UNetSmall-256': 32,
        'UNet-64': 256,
        'UNet-256': 16,
        'UNetResNet34-224': 128,       # ok
        'UNetEfficientNetB0-224': 64,
        'UNetConvNext-224': 32,
        'HRNetW18-512': 4,
        'HRNetW32-512': 4,
        'HRNetW48-512': 4,
    }.items()
}

In [None]:
# Model hyper-parameter grid: every combination of the listed values is trained.
if model_types == 'unets':
    model_param_grid = {
        # Model names must be keys of unet_models_batch.
        # Other options: UNetSmall-64, UNet-256, UNet-64, UNetResNet34-224,
        # UNetEfficientNetB0-224, UNetConvNext-224 (each suffixed with -{classes_mode}).
        'model': [
            f'UNetSmall-256-{classes_mode}',
        ],
        # Loss name; a trailing 'W' selects the class-weighted variant.
        # Other options tried: 'CE', 'macroF2', 'macroF2W', 'dice', 'CE-dice', 'groups'.
        'loss': ['CEW'],
        'epochs': [15],
        'patience': [3],
    }
else:
    # Fail here rather than with a NameError inside the training loop.
    raise ValueError(f"No model_param_grid defined for model_types={model_types!r}")

# Data hyper-parameter grid.
data_param_grid = {
    'batch_size': [training_batch_size],
    'dynamic_sampling': [True],
    'data_augmentation': [False],
    'num_channels': [8, 4],     # keys of channels_dict; also available: 12, 10, 6
    'patch_size': [256],
}



In [None]:

# Split the sub-tiles of the selected tiles into train/val/test sets
# (stratified according to classes_mode). Only the train and val fractions are
# given; the test share is presumably the remainder — confirm in preprocess_data.
train_files, val_files, test_files = data.train_val_test_stratify(
    tiles,
    num_subtiles,
    train_size=0.6,
    val_size=0.2,
    stratify_by=classes_mode,
    subfolder='q_12ch',
)



## Loop de treino

Em dois loops aninhados: um sobre os parâmetros dos dados e outro sobre os parâmetros dos modelos.

Percorre todas as combinações de parâmetros dos dois grids.

In [None]:
def _channel_indices(num_ch):
    """Return the position of each band of ``channels_dict[num_ch]`` in ``channels_dict[12]``.

    The split was produced with subfolder='q_12ch'; this assumes those files
    store their channels in ``channels_dict[12]`` order — TODO confirm.
    (Enumerating ``channels_dict[num_ch]`` itself would always give
    ``0..num_ch-1``, i.e. the wrong bands for every subset but 12.)
    """
    full_order = channels_dict[12]
    return [full_order.index(band) for band in channels_dict[num_ch]]


def _make_dataset(yaml_path, split, data_params, indices, is_train):
    """Build the ``SubtileDataset`` for one split.

    Dynamic sampling and data augmentation are taken from ``data_params`` only
    for the training split; validation and test always have both disabled.
    """
    return data.SubtileDataset(yaml_path,
                               set=split,
                               patch_size=data_params['patch_size'],
                               stride=data_params['patch_size'],
                               classes_mode=classes_mode,
                               channels_subset=indices,
                               dynamic_sampling=data_params['dynamic_sampling'] if is_train else False,
                               data_augmentation=data_params['data_augmentation'] if is_train else False,
                               )


# Architecture (first '-'-separated token of a model name) -> constructor taking
# the number of input channels.
_MODEL_BUILDERS = {
    'UNetSmall': lambda ch: unets.UNetSmall(in_channels=ch, out_channels=num_classes),
    'UNet': lambda ch: unets.UNet(in_channels=ch, out_channels=num_classes),
    'UNetResNet34': lambda ch: unets.UNetResNet34(in_channels=ch, out_channels=num_classes),
    'UNetEfficientNetB0': lambda ch: unets.UNetEfficientNetB0(in_channels=ch, out_channels=num_classes),
    'UNetConvNext': lambda ch: unets.UNetConvNext(in_channels=ch, out_channels=num_classes),
    'HRNetW18': lambda ch: hrnets.HRNetSegmentation(in_channels=ch, num_classes=num_classes,
                                                    backbone="hrnet_w18_small", pretrained=True),
    'HRNetW32': lambda ch: hrnets.HRNetSegmentation(in_channels=ch, num_classes=num_classes,
                                                    backbone="hrnet_w32", pretrained=True),
    'HRNetW48': lambda ch: hrnets.HRNetSegmentation(in_channels=ch, num_classes=num_classes,
                                                    backbone="hrnet_w48", pretrained=True),
}

# Set to True to evaluate each saved checkpoint on the test split after training.
RUN_TEST = False

# The YAML split file depends only on the tile selection, not on data_params.
yaml_filename = data.yaml_filename(num_subtiles, tiles, classes_mode)

for data_params in train.iterate_parameter_grid(data_param_grid):

    num_ch = data_params['num_channels']
    indices = _channel_indices(num_ch)

    train_dataset = _make_dataset(yaml_filename, 'train_files', data_params, indices, is_train=True)
    val_dataset = _make_dataset(yaml_filename, 'val_files', data_params, indices, is_train=False)
    test_dataset = _make_dataset(yaml_filename, 'test_files', data_params, indices, is_train=False)

    for model_params in train.iterate_parameter_grid(model_param_grid):
        base_name = model_params['model']
        loss_name = model_params['loss']
        # A trailing 'W' marks the class-weighted variant of a loss (e.g. 'CEW').
        weighted_loss = loss_name.endswith('W')

        # Plain (unweighted) loss is only trained together with dynamic sampling.
        if not weighted_loss and not data_params['dynamic_sampling']:
            continue

        # Run name: <model>[-DS][-DA]-<loss>-<channels>ch-<number of train tiles>tt
        model_name = base_name
        if data_params['dynamic_sampling']:
            model_name += '-DS'
        if data_params['data_augmentation']:
            model_name += '-DA'
        model_name += f'-{loss_name}'
        model_name += f"-{num_ch}ch"
        model_name += f"-{len(tiles)}tt"

        # Check before any lookup so an unknown model is skipped instead of raising KeyError.
        builder = _MODEL_BUILDERS.get(base_name.split('-')[0])
        if builder is None or base_name not in unet_models_batch:
            print(f'Modelo {model_name} não está no param grid. Pulando...')
            continue

        # Per-architecture batch size, capped by the batch size of the data grid.
        training_batch_size = min(data_params['batch_size'], unet_models_batch[base_name])

        print('--------------------')
        print('Training', model_name)
        print(model_params)

        if weighted_loss:
            class_counts, _ = train_dataset.count_classes()
            # Inverse class frequencies normalised to sum to 1. Counts are clamped to >= 1
            # so a class absent from the training split gets a large finite weight instead
            # of inf (which would turn every weight into NaN after normalisation).
            # float32 assumes the model outputs float32 logits — confirm if using AMP/double.
            class_counts = torch.as_tensor(class_counts, dtype=torch.float32).clamp(min=1)
            class_weights = 1.0 / class_counts
            class_weights = class_weights / torch.sum(class_weights)
        else:
            class_weights = None

        train_loader = DataLoader(train_dataset, batch_size=training_batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=training_batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=training_batch_size, shuffle=False)

        # Drop the previous run's model before building the next one so its weights
        # can be freed first, instead of both models coexisting during construction.
        model = None
        torch.cuda.empty_cache()
        model = builder(num_ch).to(device)

        train.train_model(model,
                          train_loader,
                          val_loader,
                          epochs=model_params['epochs'],
                          # strip only the trailing 'W' marker, not every 'W' in the name
                          loss_mode=loss_name[:-1] if weighted_loss else loss_name,
                          device=device,
                          num_classes=num_classes,
                          simulated_batch_size=training_batch_size,
                          patience=model_params['patience'],
                          weights=class_weights,
                          show_batches=1,
                          save_to=model_name + '.pth')

        if RUN_TEST:
            try:
                train.test_model(model,
                                 checkpoint_path=model_name + '.pth',
                                 dataloader=test_loader,
                                 device=device,
                                 num_classes=num_classes)
            except Exception as exc:  # best-effort: a failed evaluation must not stop the sweep
                print('ERROR IN TESTING', exc)

            


# Fim do treinamento