# Analyze loss extrema to derive initial loss weights for training

In [1]:
import pandas as pd

## Helper functions

In [2]:
def read_loss(training_id, base_dir='/Volumes/DD_FGS/MICS/data_HE2CellType/HE2CT/trainings'):
    """Load the loss extrema table saved for one training run.

    Parameters
    ----------
    training_id : str
        Name of the training sub-directory (e.g. ``'training_1'``).
    base_dir : str or path-like, optional
        Directory containing one sub-directory per training. Defaults to the
        original hard-coded location so existing calls keep working; pass
        another directory to run the notebook on a different machine.

    Returns
    -------
    pandas.DataFrame
        Contents of ``<base_dir>/<training_id>/train_losses_extrema_before.csv``.
    """
    return pd.read_csv(f'{base_dir}/{training_id}/train_losses_extrema_before.csv')

In [3]:
def calculate_loss_weights(df_loss_extrema):
    """Derive per-loss weights from the loss extrema table and print them.

    Two candidate weight sets are computed for every loss column, i.e. every
    column except ``Scenario`` and ``total_loss``:

    * range weights:  ``1 / (Max - Min)``
    * random weights: ``1 / Random``

    A weight of 0 is used whenever the denominator is 0, so a constant or
    zero-valued loss does not cause a division by zero.

    Parameters
    ----------
    df_loss_extrema : pandas.DataFrame
        Table with a ``Scenario`` column containing at least one row for each
        of ``'Max'``, ``'Min'`` and ``'Random'``. The first matching row of
        each scenario is used. ``Scenario`` may be at any column position.

    Returns
    -------
    tuple of (dict, dict)
        ``(loss_weights_range, loss_weights_random)``, both keyed by the loss
        column name.

    Raises
    ------
    IndexError
        If one of the three scenarios has no row in ``df_loss_extrema``.
    """
    non_loss_columns = ["Scenario", "total_loss"]

    def _scenario_losses(scenario):
        # First row of the requested scenario, restricted to the loss columns.
        rows = df_loss_extrema[df_loss_extrema['Scenario'] == scenario]
        if rows.empty:
            raise IndexError(f"no row with Scenario == {scenario!r} in df_loss_extrema")
        # Drop by name rather than by position so the column order is irrelevant.
        return rows.iloc[0].drop(labels=non_loss_columns, errors="ignore")

    def _inverse_or_zero(values):
        # 1/x per column, with 0 standing in for an undefined inverse.
        return {col: 1 / values[col] if values[col] != 0 else 0 for col in values.index}

    def _report(title, weights):
        print(title)
        for key, value in weights.items():
            # Strip the last "_"-separated token (e.g. "np_ft_loss" -> "np_ft").
            print(f"weight_{key.rsplit('_', 1)[0]} = {round(value, 2)}")

    max_values = _scenario_losses('Max')
    min_values = _scenario_losses('Min')
    random_values = _scenario_losses('Random')

    loss_weights_range = _inverse_or_zero(max_values - min_values)
    loss_weights_random = _inverse_or_zero(random_values)

    _report("Loss weights for range values:", loss_weights_range)
    _report("\nLoss weights for random values:", loss_weights_random)

    return loss_weights_range, loss_weights_random

In [4]:
def apply_weights(df_loss_extrema, loss_weights):
    """Return a copy of the extrema table with each loss scaled by its weight.

    Every column named in ``loss_weights`` is multiplied by the associated
    weight, and ``total_loss`` is replaced by the row-wise sum of those
    weighted columns. The input DataFrame is left untouched.

    Parameters
    ----------
    df_loss_extrema : pandas.DataFrame
        Table containing (at least) every column listed in ``loss_weights``.
    loss_weights : dict
        Mapping ``column name -> multiplicative weight``.

    Returns
    -------
    pandas.DataFrame
        Weighted copy of ``df_loss_extrema``.
    """
    scaled = df_loss_extrema.copy()
    weighted_columns = list(loss_weights)

    for column in weighted_columns:
        scaled[column] = scaled[column] * loss_weights[column]

    # total_loss only accounts for the columns that received a weight.
    scaled['total_loss'] = scaled[weighted_columns].sum(axis=1)
    return scaled

## training_1

In [5]:
# Training run analysed in this section; passed to read_loss() below.
training_id = 'training_1'

In [6]:
# Load the loss extrema table for the selected training and display it
# rounded to 2 decimals (the stored frame itself is not rounded).
df_loss_extrema = read_loss(training_id)
df_loss_extrema.round(2)

Unnamed: 0,Scenario,np_ft_loss,np_dice_loss,hv_mse_loss,hv_msge_loss,nt_bce_loss,nt_dice_loss,nt_ft_loss,tissue_ce_loss,total_loss
0,Min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Max,1.0,2.0,1.0,2.66,16.12,10.0,1.05,16.24,50.07
2,Random,0.4,1.08,0.4,2.03,2.59,9.53,1.0,2.89,19.92


In [7]:
# Before choosing weights (defaults = 1): compute and print both candidate
# weight sets (range-based and random-based) from the loaded extrema.
loss_weigths_range, loss_weights_random = calculate_loss_weights(df_loss_extrema)

Loss weights for range values:
weight_np_ft = 1.0
weight_np_dice = 0.5
weight_hv_mse = 1.0
weight_hv_msge = 0.38
weight_nt_bce = 0.06
weight_nt_dice = 0.1
weight_nt_ft = 0.95
weight_tissue_ce = 0.06

Loss weights for random values:
weight_np_ft = 2.52
weight_np_dice = 0.92
weight_hv_mse = 2.53
weight_hv_msge = 0.49
weight_nt_bce = 0.39
weight_nt_dice = 0.1
weight_nt_ft = 1.0
weight_tissue_ce = 0.35


In [8]:
# After choosing weights, using range values: scale each loss column by its
# range-based weight, recompute total_loss, and display rounded to 2 decimals.
weighted_df_range = apply_weights(df_loss_extrema, loss_weigths_range)
weighted_df_range.round(2)

Unnamed: 0,Scenario,np_ft_loss,np_dice_loss,hv_mse_loss,hv_msge_loss,nt_bce_loss,nt_dice_loss,nt_ft_loss,tissue_ce_loss,total_loss
0,Min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,8.0
2,Random,0.4,0.54,0.4,0.76,0.16,0.95,0.95,0.18,4.34


## training_2

In [5]:
# Training run analysed in this section; passed to read_loss() below.
training_id = 'training_2'

In [6]:
# Load the loss extrema table for the selected training and display it
# rounded to 2 decimals (the stored frame itself is not rounded).
df_loss_extrema = read_loss(training_id)
df_loss_extrema.round(2)

Unnamed: 0,Scenario,np_ft_loss,np_dice_loss,hv_mse_loss,hv_msge_loss,nt_bce_loss,nt_dice_loss,nt_ft_loss,tissue_ce_loss,total_loss
0,Min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Max,1.0,2.0,1.0,2.46,16.12,10.0,1.05,16.24,49.86
2,Random,0.4,1.06,0.4,1.98,2.59,9.53,1.0,2.89,19.85


In [7]:
# Before choosing weights (defaults = 1): compute and print both candidate
# weight sets (range-based and random-based) from the loaded extrema.
loss_weigths_range, loss_weights_random = calculate_loss_weights(df_loss_extrema)

Loss weights for range values:
weight_np_ft = 1.0
weight_np_dice = 0.5
weight_hv_mse = 1.0
weight_hv_msge = 0.41
weight_nt_bce = 0.06
weight_nt_dice = 0.1
weight_nt_ft = 0.96
weight_tissue_ce = 0.06

Loss weights for random values:
weight_np_ft = 2.52
weight_np_dice = 0.94
weight_hv_mse = 2.48
weight_hv_msge = 0.5
weight_nt_bce = 0.39
weight_nt_dice = 0.1
weight_nt_ft = 1.0
weight_tissue_ce = 0.35


In [8]:
# After choosing weights, using range values: scale each loss column by its
# range-based weight, recompute total_loss, and display rounded to 2 decimals.
weighted_df_range = apply_weights(df_loss_extrema, loss_weigths_range)
weighted_df_range.round(2)

Unnamed: 0,Scenario,np_ft_loss,np_dice_loss,hv_mse_loss,hv_msge_loss,nt_bce_loss,nt_dice_loss,nt_ft_loss,tissue_ce_loss,total_loss
0,Min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,8.0
2,Random,0.4,0.53,0.4,0.81,0.16,0.95,0.96,0.18,4.38


**==> Choice: range**

## training_3

- Loss extrema: same as training_2

- Results after applying the authors' pre-trained CellViT model on the train set:
    
    - total loss: 2.7592
    
    - nuclei_binary_map_focaltverskyloss: 0.0314
    - nuclei_binary_map_dice: 0.1987
    
    - hv_map_mse: 0.0318
    - hv_map_msge: 0.3388
    
    - nuclei_type_map_bce: 5.7312
    - nuclei_type_map_dice: 9.7729
    - nuclei_type_map_mcfocaltverskyloss: 1.0203
    
    - tissue_types_ce: 2.6194

In [1]:
# Loss weights using pre-trained model on first epoch on train set.
# Each weight is the inverse of the loss value listed in the markdown cell above.
pretrained_losses = {
    "np_ft": 0.0314,
    "np_dice": 0.1987,
    "hv_mse": 0.0318,
    "hv_msge": 0.3388,
    "nt_bce": 5.7312,
    "nt_dice": 9.7729,
    "nt_ft": 1.0203,
    "tissue_ce": 2.6194,
}
# Losses whose computed weight is reported but overridden to 0 for training_3.
nt_only_disabled = ("np_ft", "np_dice", "hv_mse", "hv_msge")
nt_only_note = "BUT put 0 for training_3 as we will use 'NTonly' for training"

print("Loss weights using pre-trained model on first epoch on train set:")
for loss_name, loss_value in pretrained_losses.items():
    trailing = (nt_only_note,) if loss_name in nt_only_disabled else ()
    print(f"weight_{loss_name} = ", round(1 / loss_value, 2), *trailing)

Loss weights using pre-trained model on first epoch on train set:
weight_np_ft =  31.85 BUT put 0 for training_3 as we will use 'NTonly' for training
weight_np_dice =  5.03 BUT put 0 for training_3 as we will use 'NTonly' for training
weight_hv_mse =  31.45 BUT put 0 for training_3 as we will use 'NTonly' for training
weight_hv_msge =  2.95 BUT put 0 for training_3 as we will use 'NTonly' for training
weight_nt_bce =  0.17
weight_nt_dice =  0.1
weight_nt_ft =  0.98
weight_tissue_ce =  0.38
