# Treinamento dos modelos

In [None]:
import torch


In [None]:
def print_gpu_memory(prefix=""):
    """Print the current CUDA memory usage in megabytes.

    Two lines are printed, one for the memory reported by
    ``torch.cuda.memory_allocated()`` and one for
    ``torch.cuda.memory_reserved()``. When CUDA is not available, a single
    notice is printed instead.

    Args:
        prefix: Text placed at the start of each usage line, followed by a
            space.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    bytes_per_mb = 1024 ** 2
    usage = (
        ("Allocated", torch.cuda.memory_allocated()),
        ("Reserved", torch.cuda.memory_reserved()),
    )
    for label, n_bytes in usage:
        print(f"{prefix} Memory {label}: {n_bytes / bytes_per_mb:.2f} MB")


# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Release cached, currently unused GPU memory before reporting the usage.
torch.cuda.empty_cache()
print_gpu_memory()

In [None]:
# imports

import os
import sys
sys.path.append(os.path.abspath('..'))

import src.models.unets as unets
import src.data.preprocess_data as data
import src.training.train_model as train
import src.models.hrnets as hrnets

from torch.utils.data import DataLoader



### Definições:

Define quais tiles, divisão em subtiles, tipos de modelo e tipos de classe são alvo do treino.

In [None]:

# Experiment configuration: which tiles are used, into how many subtiles each
# one is divided, which label scheme is the target, and which model family
# is trained.
tiles = ['032027']  # '032026' is currently left out
num_subtiles = 6
classes_mode = 'type'
model_types = 'unets'

# Default training batch size for each supported model family.
_batch_size_by_model_type = {'hrnets': 4, 'unets': 16}

# Number of output classes for each supported label scheme.
_num_classes_by_mode = {'type': 5, 'density': 4, 'binary': 2, 'all': 9}

# Fail fast on a misspelled option: otherwise the derived variable would be
# left undefined and only surface as a NameError in a later cell.
if model_types not in _batch_size_by_model_type:
    raise ValueError(
        f"Unsupported model_types {model_types!r}; "
        f"expected one of {sorted(_batch_size_by_model_type)}"
    )
if classes_mode not in _num_classes_by_mode:
    raise ValueError(
        f"Unsupported classes_mode {classes_mode!r}; "
        f"expected one of {sorted(_num_classes_by_mode)}"
    )

training_batch_size = _batch_size_by_model_type[model_types]
num_classes = _num_classes_by_mode[classes_mode]





Aqui definimos o batch size máximo para cada modelo

In [None]:
# Maximum batch size configured for each model. Keys are the full model names,
# i.e. "<architecture>-<patch size>-<classes_mode>".
unet_models_batch = {
    f'{base_name}-{classes_mode}': max_batch
    for base_name, max_batch in (
        ('UNetSmall-64', 16),  # original note: 512
        ('UNetSmall-256', 32),
        ('UNet-64', 256),
        ('UNet-256', 16),
        ('UNetResNet34-224', 128),  # ok
        ('UNetEfficientNetB0-224', 64),
        ('UNetConvNext-224', 32),
        ('HRNetW18-512', 4),
        ('HRNetW32-512', 4),
        ('HRNetW48-512', 4),
    )
}

## Grade de parâmetros

Muitas variações dos modelos vão ser treinadas. Aqui definimos quais variações considerar: modelos, tipo de loss, uso de ponderação, amostragem dinâmica, etc.



In [None]:
# Candidate model names for each supported model family, in the order in
# which they are listed in the grid.
_models_by_type = {
    'unets': [
        f'UNetSmall-64-{classes_mode}',
        f'UNetSmall-256-{classes_mode}',
        f'UNet-256-{classes_mode}',  # ok
        f'UNet-64-{classes_mode}',  # ok
        f'UNetResNet34-224-{classes_mode}',  # ok
        f'UNetEfficientNetB0-224-{classes_mode}',
        f'UNetConvNext-224-{classes_mode}',
    ],
    # The 1024-patch HRNet variants were listed but disabled in the original grid.
    'hrnets': [
        f'HRNetW18-512-{classes_mode}',
        f'HRNetW32-512-{classes_mode}',
        f'HRNetW48-512-{classes_mode}',
    ],
}

# Fail fast instead of leaving `model_param_grid` undefined (or stale from a
# previous notebook run) when the family name is not recognised.
if model_types not in _models_by_type:
    raise ValueError(
        f"Unsupported model_types {model_types!r}; "
        f"expected one of {sorted(_models_by_type)}"
    )

# Every key maps to the list of values to try. The training options are the
# same for both model families, so they are written only once.
model_param_grid = {
    # model params
    'model': _models_by_type[model_types],

    # training params
    # Only plain cross-entropy is enabled; 'CE-dice', 'dice' and 'groups'
    # appeared as disabled alternatives in the original grid.
    'loss': ['CE'],
    # Original note: weighted loss + CE gives good recall for classes 2, 3, 4
    # and poor recall for the rest.
    'weighted_loss': [False, True],
    'dist_loss': [False],
    'crf': [False],  # a disabled alternative value was [0.0001]
    'epochs': [15],
    'patience': [3],
    'batch_size': [training_batch_size],
    'dynamic_sampling': [False, True],
    'data_augmentation': [False],
}



Separação em treino, validação e teste, com estratificação.

In [None]:

# Stratified split of the subtiles into train / validation / test file lists.
# 60% train and 20% validation are requested explicitly; the remaining share
# is presumably the test set — confirm in data.train_val_test_stratify.
train_files, val_files, test_files = data.train_val_test_stratify(
    tiles,
    num_subtiles,
    train_size=0.6,
    val_size=0.2,
    stratify_by=classes_mode,
)



## Loop de treino:

Varre a grade de parâmetros, carrega o dataset correspondente, instancia o modelo e treina.

Os modelos ficam salvos em models

Os resultados e métricas de treino ficam salvos em experimental_results


In [None]:
import traceback

# Train and evaluate one model for every combination in the parameter grid.
for model_params in train.iterate_parameter_grid(model_param_grid):

    # Grid entries are named "<architecture>-<patch size>-<classes_mode>".
    model_name = model_params['model']

    # Cap the batch size at 16, or lower when the per-model table says so.
    # NOTE(review): model_params['batch_size'] is not used by this loop.
    training_batch_size = min(16, unet_models_batch[model_name])

    # Extend the name with one suffix per enabled option so that every grid
    # combination gets its own checkpoint file.
    if model_params['crf']:
        model_name += '-crf'
    if model_params['dist_loss']:
        model_name += '-dist'
    if model_params['dynamic_sampling']:
        model_name += '-DS'
    if model_params['data_augmentation']:
        model_name += '-DA'
    model_name += f'-{model_params["loss"]}'
    if model_params['weighted_loss']:
        model_name += 'W'

    # Suffixes are only appended, so the first two '-' separated fields are
    # still the architecture and the patch size.
    model_class = model_name.split('-')[0]
    patch_size = int(model_name.split('-')[1])

    print('--------------------')
    print('Training', model_name)
    print(model_params)

    # Weighted loss combined with dynamic sampling or data augmentation is
    # deliberately NOT skipped; the original code carried a disabled
    # (`if 0:`) guard for that combination.

    # Load data: non-overlapping patches (stride == patch size) for all splits.
    yaml_filename = data.yaml_filename(num_subtiles, tiles, classes_mode)
    train_dataset = data.SubtileDataset(
        yaml_filename,
        set='train_files',
        patch_size=patch_size,
        stride=patch_size,
        dynamic_sampling=model_params['dynamic_sampling'],
        data_augmentation=model_params['data_augmentation'],  # testing
    )

    # Validation and test data are never resampled or augmented.
    val_dataset = data.SubtileDataset(
        yaml_filename,
        set='val_files',
        patch_size=patch_size,
        stride=patch_size,
        dynamic_sampling=False,
        data_augmentation=False,  # testing
    )

    test_dataset = data.SubtileDataset(
        yaml_filename,
        set='test_files',
        patch_size=patch_size,
        stride=patch_size,
        dynamic_sampling=False,
        data_augmentation=False,  # testing
    )

    if model_params['weighted_loss']:
        class_counts, per_image = train_dataset.count_classes()
        # NOTE(review): a class with a zero count would yield inf/nan weights
        # here — confirm count_classes never returns zeros.
        class_weights = 1.0 / class_counts  # Inverse of class frequencies
        class_weights = class_weights / torch.sum(class_weights)  # Normalize
    else:
        class_weights = None

    # Not used below; kept as module-level names in case later cells read them.
    dynamic_sampling = train_dataset.dynamic_sampling
    data_augmentation = train_dataset.data_augmentation

    train_loader = DataLoader(train_dataset, batch_size=training_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=training_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=training_batch_size, shuffle=False)

    # Instantiate the architecture selected by the name prefix. Every model
    # receives 12 input channels and `num_classes` outputs.
    if model_name.startswith('UNetSmall-'):
        model = unets.UNetSmall(
            in_channels=12,
            out_channels=num_classes,
            crf=model_params['crf'],
            use_dist=model_params['dist_loss'],
        ).to(device)
    elif model_name.startswith('UNet-'):
        model = unets.UNet(
            in_channels=12, out_channels=num_classes, crf=model_params['crf']
        ).to(device)
    elif model_name.startswith('UNetResNet34-'):
        model = unets.UNetResNet34(
            in_channels=12, out_channels=num_classes, crf=model_params['crf']
        ).to(device)
    elif model_name.startswith('UNetEfficientNetB0-'):
        model = unets.UNetEfficientNetB0(
            in_channels=12, out_channels=num_classes, crf=model_params['crf']
        ).to(device)
    elif model_name.startswith('UNetConvNext-'):
        model = unets.UNetConvNext(
            in_channels=12, out_channels=num_classes, crf=model_params['crf']
        ).to(device)
    elif model_name.startswith('HRNetW18'):
        model = hrnets.HRNetSegmentation(
            in_channels=12, num_classes=num_classes,
            backbone="hrnet_w18_small", pretrained=True,
        ).to(device)
    elif model_name.startswith('HRNetW32'):
        model = hrnets.HRNetSegmentation(
            in_channels=12, num_classes=num_classes,
            backbone="hrnet_w32", pretrained=True,
        ).to(device)
    elif model_name.startswith('HRNetW48'):
        model = hrnets.HRNetSegmentation(
            in_channels=12, num_classes=num_classes,
            backbone="hrnet_w48", pretrained=True,
        ).to(device)
    else:
        print(f'Modelo {model_name} não está no param grid. Pulando...')
        continue

    print(model_params['loss'])
    train.train_model(
        model,
        train_loader,
        val_loader,
        epochs=model_params['epochs'],
        loss_mode=model_params['loss'],
        device=device,
        num_classes=num_classes,
        simulated_batch_size=training_batch_size,
        patience=model_params['patience'],
        weights=class_weights,
        show_batches=1,
        save_to=model_name + '.pth',
    )

    # Evaluation is best-effort: a failure must not abort the remaining grid
    # combinations. Only ordinary exceptions are caught, so Ctrl-C still
    # interrupts the run, and the traceback is printed so the cause is visible.
    try:
        train.test_model(
            model,
            checkpoint_path=model_name + '.pth',
            dataloader=test_loader,
            device=device,
            num_classes=num_classes,
        )
    except Exception:
        print('ERROR IN TESTING')
        traceback.print_exc()

    
