# Calculate class weights for the FT loss and tissue statistics per dataset

In [1]:
import pandas as pd
import json
import numpy as np

In [2]:
# More advanced class weighting for the FT loss of the NT branch: the background weight is
# included in the normalization, and ignored categories keep a small non-zero weight.
def compute_class_weights(
    ds_train: pd.DataFrame,
    cell_type_cols: list,
    ignore_cat: list,           # e.g. ["unrelevant_class1", "unrelevant_class2"]
    fraction_unrelevant=0.1,    # fraction of the min relevant weight to assign to unrelevant classes
    fraction_background=1.0,    # fraction of the min relevant weight to assign to the background
) -> list:
    """
    Compute normalized class weights for a multi-class segmentation/classification loss.

    The pipeline:
      1) Compute frequency-based log-weights w_c = log(1 + 1/freq_c) for relevant classes,
         where freq_c is column c's total divided by the grand total of all class columns.
      2) Give each ignored ("unrelevant") class `fraction_unrelevant` times the minimal
         relevant weight.
      3) Give the background `fraction_background` times the minimal relevant weight.
      4) Normalize so that all weights (background included) sum to 1.

    Args:
      ds_train: DataFrame with one numeric column per class listed in `cell_type_cols`.
      cell_type_cols: Class column names, WITHOUT "Background" (it is added here).
      ignore_cat: Classes to down-weight; names not in `cell_type_cols` have no effect.
      fraction_unrelevant: Multiplier on the minimal relevant weight for ignored classes.
      fraction_background: Multiplier on the minimal relevant weight for the background.

    Returns:
      A list of floats in the order [Background, class_1, ..., class_n], with the classes
      in the order of `cell_type_cols`. The rounded weights are also printed.

    Raises:
      ValueError: if "Background" is listed in `cell_type_cols`, if the class columns sum
        to zero, if a relevant class never occurs (its log-weight would be infinite), or
        if the final weights do not sum to a positive value.
    """
    if "Background" in cell_type_cols:
        raise ValueError("Do not include 'Background' in cell_type_cols; it is added automatically.")

    # Baseline frequencies: each column total over the grand total of all class columns
    class_totals = ds_train[cell_type_cols].sum()
    grand_total = class_totals.sum()
    if grand_total <= 0:
        raise ValueError("The class columns of ds_train sum to zero; cannot compute frequencies.")
    cell_type_frequencies = class_totals / grand_total

    # Split the classes into relevant and ignored ("unrelevant") ones
    ignored = set(ignore_cat)
    relevant_cats = [c for c in cell_type_cols if c not in ignored]
    unrelevant_cats = [c for c in cell_type_cols if c in ignored]

    # A relevant class that never occurs would get log(1 + 1/0) = inf and poison the normalization
    absent = [c for c in relevant_cats if cell_type_frequencies[c] == 0]
    if absent:
        raise ValueError(
            f"Relevant classes {absent} never occur in ds_train; their log-weight would be "
            "infinite. Add them to ignore_cat or check the data."
        )

    # Log-based weights for relevant classes: w_c = log(1 + 1/freq_c)
    weights = {c: float(np.log(1.0 + 1.0 / cell_type_frequencies[c])) for c in relevant_cats}

    # Reference weight; falls back to 1.0 when every class is ignored
    min_relevant_weight = min(weights.values()) if weights else 1.0

    # Ignored classes and background are expressed relative to the smallest relevant weight
    for cat in unrelevant_cats:
        weights[cat] = fraction_unrelevant * min_relevant_weight
    weights["Background"] = fraction_background * min_relevant_weight

    # Final order: [Background, class_1, ..., class_n]
    ordered_cats = ["Background"] + list(cell_type_cols)
    final_weights = [weights[cat] for cat in ordered_cats]

    # Normalize so the weights sum to 1 (`not > 0` also rejects NaN sums)
    sum_w = sum(final_weights)
    if not sum_w > 0:
        raise ValueError("Sum of final weights is zero. Check for errors in the computation.")
    final_weights = [w / sum_w for w in final_weights]

    print("WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:")
    print(ordered_cats)
    print([round(w, 3) for w in final_weights])
    return final_weights

## ds_1

In [3]:
# Section configuration: dataset, cell-type annotation set, and classes to ignore in the loss.
# NOTE(review): the compute_class_weights call in this section passes ignore_cat=["Dead"] explicitly.
dataset_id, cell_cat_id = 'ds_1', 'ct_1'
ignore_cat = []

In [5]:
# Class names come from the annotation colour map. "Background" must not be among them,
# because it is prepended separately when the weight list is built.
with open(f"/Volumes/DD1_FGS/MICS/data_HE2CellType/CT_DS/annots/annot_dicts_{cell_cat_id}/cat2color.json", "r", encoding="utf-8") as f:
    cat2color = json.load(f)

cell_type_cols = list(cat2color.keys())  # do not provide "Background"
assert "Background" not in cell_type_cols, "cat2color.json must not contain 'Background'"
cell_type_cols

['T_NK',
 'B_Plasma',
 'Myeloid',
 'Blood_vessel',
 'Fibroblast_Myofibroblast',
 'Epithelial',
 'Specialized',
 'Melanocyte',
 'Dead']

In [6]:
# Per-patch information table for this dataset; only the training split is kept
ds_infos = pd.read_csv(f'/Volumes/DD1_FGS/MICS/data_HE2CellType/HE2CT/training_datasets/{dataset_id}/informations/infos_{dataset_id}.csv')
ds_train = ds_infos.loc[ds_infos['set'].eq('train')]
ds_train

Unnamed: 0,T_NK,B_Plasma,Myeloid,Blood_vessel,Fibroblast_Myofibroblast,Epithelial,Specialized,Melanocyte,Dead,img,type,Dice,Jaccard,bPQ,slide_id,set
0,0,0,1,0,1,0,2,0,0,heart_s0_362.png,Heart,0.777565,0.636079,0.480884,heart_s0,train
1,0,0,0,2,1,0,4,0,0,heart_s0_363.png,Heart,0.709650,0.549967,0.378748,heart_s0,train
2,0,0,0,1,2,0,1,0,0,heart_s0_364.png,Heart,0.813963,0.686289,0.603601,heart_s0,train
3,0,0,0,2,2,0,0,0,0,heart_s0_365.png,Heart,0.833883,0.715094,0.647057,heart_s0,train
6,0,0,0,1,0,0,0,0,0,heart_s0_437.png,Heart,0.920128,0.852071,0.852069,heart_s0,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
508475,0,0,0,0,0,6,0,0,33,cervix_s0_46410.png,Cervix,0.711000,0.551590,0.358931,cervix_s0,train
508476,0,0,0,0,0,20,0,0,46,cervix_s0_46411.png,Cervix,0.726411,0.570366,0.332126,cervix_s0,train
508477,0,0,0,0,0,26,0,0,24,cervix_s0_46412.png,Cervix,0.678762,0.513731,0.348578,cervix_s0,train
508479,0,0,0,0,0,18,0,0,2,cervix_s0_46415.png,Cervix,0.706722,0.546458,0.462859,cervix_s0,train


In [7]:
# Compute the pixel-wise frequency of each class
cell_type_frequencies = ds_train[cell_type_cols].sum() / ds_train[cell_type_cols].values.sum()

# Apply the logarithmic scaling to compute weights
cell_type_weights = np.log(1 + (1 / cell_type_frequencies))

# Set weights to zero for ignored categories
for cat in ignore_cat:
    if cat in cell_type_cols:
        cell_type_weights[cell_type_cols.index(cat)] = 0

# Normalize the weights (excluding the background for now)
cell_type_weights /= cell_type_weights.sum()

cell_type_weights

T_NK                        0.080978
B_Plasma                    0.095278
Myeloid                     0.086617
Blood_vessel                0.103765
Fibroblast_Myofibroblast    0.084204
Epithelial                  0.053929
Specialized                 0.218707
Melanocyte                  0.163323
Dead                        0.113197
dtype: float64

In [8]:
# Convert to a list for further processing
list_cell_type_weights = cell_type_weights.tolist()

# Add background weight at the beginning (minimum weight from normalized weights)
background_weight = min([w for w in list_cell_type_weights if w > 0])
list_cell_type_weights.insert(0, background_weight)

# Round the weights to the desired precision
list_cell_type_weights = [round(w, 3) for w in list_cell_type_weights]

print("WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:")
print(list_cell_type_weights)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
[0.054, 0.081, 0.095, 0.087, 0.104, 0.084, 0.054, 0.219, 0.163, 0.113]


In [9]:
# The count includes the background entry prepended to the weight list
print("NUMBER OF CLASSES INCLUDING BACKGROUND:", len(list_cell_type_weights), sep="\n")

NUMBER OF CLASSES INCLUDING BACKGROUND:
10


In [10]:
# Class order matching the weight list: background first, then cell_type_cols
print("CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:", ["Background", *cell_type_cols], sep="\n")

CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:
['Background', 'T_NK', 'B_Plasma', 'Myeloid', 'Blood_vessel', 'Fibroblast_Myofibroblast', 'Epithelial', 'Specialized', 'Melanocyte', 'Dead']


In [11]:
# Patches per tissue type in the training split, ordered alphabetically by tissue name
print("COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):")
ds_train['type'].value_counts(sort=False).sort_index()

COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):


Breast        131868
Cervix         21296
Colon          12921
Heart           2888
Kidney          6604
Liver          17269
Lung           16566
LymphNode       7581
Ovarian        18476
Pancreatic     22330
Prostate        7355
Skin           16422
Tonsil         23505
Name: type, dtype: int64

In [12]:
# Distinct tissue types in the training split (a missing value would count as one type)
print("NUMBER OF TISSUES:")
print(ds_train['type'].nunique(dropna=False))

NUMBER OF TISSUES:
13


In [13]:
# Advanced weighting: "Dead" gets 10% of the smallest relevant weight, the background
# gets 100% of it, and all weights are normalized together.
compute_class_weights(
    ds_train,
    cell_type_cols,
    ignore_cat=["Dead"],
    fraction_background=1.0,
    fraction_unrelevant=0.1,
)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
['Background', 'T_NK', 'B_Plasma', 'Myeloid', 'Blood_vessel', 'Fibroblast_Myofibroblast', 'Epithelial', 'Specialized', 'Melanocyte', 'Dead']
[0.057, 0.086, 0.101, 0.092, 0.11, 0.089, 0.057, 0.231, 0.173, 0.006]


## ds_2

In [2]:
# Section configuration: dataset, cell-type annotation set, and classes to ignore in the loss.
# NOTE(review): the compute_class_weights call in this section passes ignore_cat=["Dead"] explicitly.
dataset_id, cell_cat_id = 'ds_2', 'ct_1'
ignore_cat = []

In [None]:
# Class names come from the annotation colour map. "Background" must not be among them,
# because it is prepended separately when the weight list is built.
with open(f"/Volumes/DD1_FGS/MICS/data_HE2CellType/CT_DS/annots/annot_dicts_{cell_cat_id}/cat2color.json", "r", encoding="utf-8") as f:
    cat2color = json.load(f)

cell_type_cols = list(cat2color.keys())  # do not provide "Background"
assert "Background" not in cell_type_cols, "cat2color.json must not contain 'Background'"
cell_type_cols

['T_NK',
 'B_Plasma',
 'Myeloid',
 'Blood_vessel',
 'Fibroblast_Myofibroblast',
 'Epithelial',
 'Specialized',
 'Melanocyte',
 'Dead']

In [None]:
# Per-patch information table for this dataset; only the training split is kept
ds_infos = pd.read_csv(f'/Volumes/DD1_FGS/MICS/data_HE2CellType/HE2CT/training_datasets/{dataset_id}/informations/infos_{dataset_id}.csv')
ds_train = ds_infos.loc[ds_infos['set'].eq('train')]
ds_train

Unnamed: 0,T_NK,B_Plasma,Myeloid,Blood_vessel,Fibroblast_Myofibroblast,Epithelial,Specialized,Melanocyte,Dead,img,type,Dice,Jaccard,bPQ,slide_id,set
0,0,0,0,2,2,0,0,0,0,heart_s0_365.png,Heart,0.833883,0.715094,0.647057,heart_s0,train
2,0,0,0,0,0,0,3,0,0,heart_s0_480.png,Heart,0.851594,0.741544,0.726854,heart_s0,train
5,0,0,0,3,2,0,6,0,0,heart_s0_487.png,Heart,0.851922,0.742041,0.710735,heart_s0,train
7,0,0,1,2,2,0,1,0,0,heart_s0_489.png,Heart,0.869257,0.768748,0.743576,heart_s0,train
8,0,0,5,3,2,0,7,0,0,heart_s0_491.png,Heart,0.860879,0.755739,0.596328,heart_s0,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
148989,0,0,0,0,2,45,0,0,0,cervix_s0_45806.png,Cervix,0.833179,0.714059,0.573759,cervix_s0,train
148990,0,0,1,4,10,23,0,0,2,cervix_s0_45815.png,Cervix,0.832595,0.713201,0.470911,cervix_s0,train
148993,0,0,0,0,0,34,0,0,0,cervix_s0_46004.png,Cervix,0.857879,0.751128,0.664169,cervix_s0,train
148994,1,0,0,0,1,63,0,0,1,cervix_s0_46214.png,Cervix,0.826384,0.704135,0.552902,cervix_s0,train


In [5]:
# Compute the pixel-wise frequency of each class
cell_type_frequencies = ds_train[cell_type_cols].sum() / ds_train[cell_type_cols].values.sum()

# Apply the logarithmic scaling to compute weights
cell_type_weights = np.log(1 + (1 / cell_type_frequencies))

# Set weights to zero for ignored categories
for cat in ignore_cat:
    if cat in cell_type_cols:
        cell_type_weights[cell_type_cols.index(cat)] = 0

# Normalize the weights (excluding the background for now)
cell_type_weights /= cell_type_weights.sum()

cell_type_weights

T_NK                        0.086009
B_Plasma                    0.100562
Myeloid                     0.090318
Blood_vessel                0.105798
Fibroblast_Myofibroblast    0.087633
Epithelial                  0.043725
Specialized                 0.193780
Melanocyte                  0.150250
Dead                        0.141926
dtype: float64

In [6]:
# Convert to a list for further processing
list_cell_type_weights = cell_type_weights.tolist()

# Add background weight at the beginning (minimum weight from normalized weights)
background_weight = min([w for w in list_cell_type_weights if w > 0])
list_cell_type_weights.insert(0, background_weight)

# Round the weights to the desired precision
list_cell_type_weights = [round(w, 3) for w in list_cell_type_weights]

print("WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:")
print(list_cell_type_weights)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
[0.044, 0.086, 0.101, 0.09, 0.106, 0.088, 0.044, 0.194, 0.15, 0.142]


In [7]:
# The count includes the background entry prepended to the weight list
print("NUMBER OF CLASSES INCLUDING BACKGROUND:", len(list_cell_type_weights), sep="\n")

NUMBER OF CLASSES INCLUDING BACKGROUND:
10


In [8]:
# Class order matching the weight list: background first, then cell_type_cols
print("CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:", ["Background", *cell_type_cols], sep="\n")

CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:
['Background', 'T_NK', 'B_Plasma', 'Myeloid', 'Blood_vessel', 'Fibroblast_Myofibroblast', 'Epithelial', 'Specialized', 'Melanocyte', 'Dead']


In [9]:
# Patches per tissue type in the training split, ordered alphabetically by tissue name
print("COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):")
ds_train['type'].value_counts(sort=False).sort_index()

COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):


type
Breast        43146
Cervix         7215
Colon           986
Heart          1444
Kidney         3965
Liver          7364
Lung           4068
LymphNode       617
Ovarian        5416
Pancreatic     2797
Prostate       1669
Skin           4905
Tonsil         5798
Name: count, dtype: int64

In [10]:
# Distinct tissue types in the training split (a missing value would count as one type)
print("NUMBER OF TISSUES:")
print(ds_train['type'].nunique(dropna=False))

NUMBER OF TISSUES:
13


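#### New weighting option (`compute_class_weights`: background included in the normalization, ignored classes keep a small weight):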

In [8]:
# Advanced weighting: "Dead" gets 10% of the smallest relevant weight, the background
# gets 100% of it, and all weights are normalized together.
compute_class_weights(
    ds_train,
    cell_type_cols,
    ignore_cat=["Dead"],
    fraction_background=1.0,
    fraction_unrelevant=0.1,
)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
['Background', 'T_NK', 'B_Plasma', 'Myeloid', 'Blood_vessel', 'Fibroblast_Myofibroblast', 'Epithelial', 'Specialized', 'Melanocyte', 'Dead']
[0.048, 0.095, 0.111, 0.1, 0.117, 0.097, 0.048, 0.214, 0.166, 0.005]


## ds_3

In [3]:
# Section configuration: dataset, cell-type annotation set, and classes to ignore in the loss
dataset_id, cell_cat_id = 'ds_3', 'ct_3'
ignore_cat = ['Other']

In [None]:
# Class names come from the annotation colour map. "Background" must not be among them,
# because it is prepended separately when the weight list is built.
with open(f"/Volumes/DD1_FGS/MICS/data_HE2CellType/CT_DS/annots/annot_dicts_{cell_cat_id}/cat2color.json", "r", encoding="utf-8") as f:
    cat2color = json.load(f)

cell_type_cols = list(cat2color.keys())  # do not provide "Background"
assert "Background" not in cell_type_cols, "cat2color.json must not contain 'Background'"
cell_type_cols

['Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']

In [None]:
# Per-patch information table for this dataset; only the training split is kept
ds_infos = pd.read_csv(f'/Volumes/DD1_FGS/MICS/data_HE2CellType/HE2CT/training_datasets/{dataset_id}/informations/infos_{dataset_id}.csv')
ds_train = ds_infos.loc[ds_infos['set'].eq('train')]
ds_train

Unnamed: 0,Epithelial,Melanocyte,Immune,Stromal,Other,img,type,Dice,Jaccard,bPQ,slide_id,set
2,0,0,0,0,3,heart_s0_480.png,Heart,0.851594,0.741544,0.726854,heart_s0,train
3,0,0,0,4,5,heart_s0_483.png,Heart,0.843884,0.729930,0.542883,heart_s0,train
5,0,0,0,5,6,heart_s0_487.png,Heart,0.851922,0.742041,0.710735,heart_s0,train
6,0,0,1,2,2,heart_s0_488.png,Heart,0.864434,0.761237,0.671951,heart_s0,train
7,0,0,1,4,1,heart_s0_489.png,Heart,0.869257,0.768748,0.743576,heart_s0,train
...,...,...,...,...,...,...,...,...,...,...,...,...
148989,45,0,0,2,0,cervix_s0_45806.png,Cervix,0.833179,0.714059,0.573759,cervix_s0,train
148990,23,0,1,14,2,cervix_s0_45815.png,Cervix,0.832595,0.713201,0.470911,cervix_s0,train
148991,0,0,1,2,0,cervix_s0_45971.png,Cervix,0.828354,0.707000,0.532206,cervix_s0,train
148992,59,0,0,0,2,cervix_s0_45987.png,Cervix,0.824035,0.700731,0.460038,cervix_s0,train


In [7]:
# One extra class for the background, which is not part of cell_type_cols
print("NUMBER OF CLASSES INCLUDING BACKGROUND:")
print(len(cell_type_cols) + 1)

NUMBER OF CLASSES INCLUDING BACKGROUND:
6


In [8]:
# Class order matching the weight list: background first, then cell_type_cols
print("CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:", ["Background", *cell_type_cols], sep="\n")

CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:
['Background', 'Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']


In [9]:
# Patches per tissue type in the training split, ordered alphabetically by tissue name
print("COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):")
ds_train['type'].value_counts(sort=False).sort_index()

COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):


Breast        43146
Cervix         7215
Colon           986
Heart          1444
Kidney         3965
Liver          7364
Lung           4068
LymphNode       617
Ovarian        5416
Pancreatic     2797
Prostate       1669
Skin           4905
Tonsil         5798
Name: type, dtype: int64

In [10]:
# Distinct tissue types in the training split (a missing value would count as one type)
print("NUMBER OF TISSUES:")
print(ds_train['type'].nunique(dropna=False))

NUMBER OF TISSUES:
13


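#### New weighting option (`compute_class_weights`: background included in the normalization, ignored classes keep a small weight):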

In [11]:
# Advanced weighting: "Other" gets 10% of the smallest relevant weight, the background
# gets 100% of it, and all weights are normalized together.
compute_class_weights(
    ds_train,
    cell_type_cols,
    ignore_cat=["Other"],
    fraction_background=1.0,
    fraction_unrelevant=0.1,
)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
['Background', 'Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']
[0.118, 0.151, 0.194, 0.118, 0.407, 0.012]


## ds_4

In [3]:
# Section configuration: dataset, cell-type annotation set, and classes to ignore in the loss
dataset_id, cell_cat_id = 'ds_4', 'ct_3'
ignore_cat = ['Other']

In [None]:
# Annotation colour map for this cell-type set; its keys are the class names.
# JSON is UTF-8 by specification, so decode explicitly instead of using the platform default.
with open(f"/Volumes/DD1_FGS/MICS/data_HE2CellType/CT_DS/annots/annot_dicts_{cell_cat_id}/cat2color.json", "r", encoding="utf-8") as f:
    cat2color = json.load(f)

In [5]:
# Class names in colour-map order. "Background" must not be among them, because it is
# prepended separately when the weight list is built.
cell_type_cols = list(cat2color.keys())  # do not provide "Background"
assert "Background" not in cell_type_cols, "cat2color.json must not contain 'Background'"
cell_type_cols

['Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']

In [None]:
# Per-patch information table for this dataset; only the training split is kept
ds_infos = pd.read_csv(f'/Volumes/DD1_FGS/MICS/data_HE2CellType/HE2CT/training_datasets/{dataset_id}/informations/infos_{dataset_id}.csv')
ds_train = ds_infos.loc[ds_infos['set'].eq('train')]
ds_train

Unnamed: 0,Epithelial,Melanocyte,Immune,Stromal,Other,img,type,Dice,Jaccard,bPQ,slide_id,set
1,0,0,0,3,4,heart_s0_363.png,Heart,0.709650,0.549967,0.378748,heart_s0,train
3,0,0,0,4,0,heart_s0_365.png,Heart,0.833883,0.715094,0.647057,heart_s0,train
5,0,0,0,3,2,heart_s0_367.png,Heart,0.676033,0.510612,0.615884,heart_s0,train
6,0,0,0,1,0,heart_s0_437.png,Heart,0.920128,0.852071,0.852069,heart_s0,train
8,0,0,0,0,3,heart_s0_480.png,Heart,0.851594,0.741544,0.726854,heart_s0,train
...,...,...,...,...,...,...,...,...,...,...,...,...
508475,6,0,0,0,33,cervix_s0_46410.png,Cervix,0.711000,0.551590,0.358931,cervix_s0,train
508476,20,0,0,0,46,cervix_s0_46411.png,Cervix,0.726411,0.570366,0.332126,cervix_s0,train
508477,26,0,0,0,24,cervix_s0_46412.png,Cervix,0.678762,0.513731,0.348578,cervix_s0,train
508481,0,0,6,8,17,cervix_s0_46421.png,Cervix,0.730785,0.575777,0.446292,cervix_s0,train


In [7]:
# One extra class for the background, which is not part of cell_type_cols
print("NUMBER OF CLASSES INCLUDING BACKGROUND:")
print(len(cell_type_cols) + 1)

NUMBER OF CLASSES INCLUDING BACKGROUND:
6


In [8]:
# Class order matching the weight list: background first, then cell_type_cols
print("CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:", ["Background", *cell_type_cols], sep="\n")

CLASSES IN RIGHT ORDER WITH BACKGROUND FIRST:
['Background', 'Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']


In [9]:
# Patches per tissue type in the training split, ordered alphabetically by tissue name
print("COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):")
ds_train['type'].value_counts(sort=False).sort_index()

COUNT OF PATCHES FOR EACH TISSUE TYPE (ALPHABETICAL ORDER):


Breast        131871
Cervix         21296
Colon          12922
Heart           2888
Kidney          6606
Liver          17271
Lung           16568
LymphNode       7582
Ovarian        18478
Pancreatic     22332
Prostate        7355
Skin           16426
Tonsil         23506
Name: type, dtype: int64

In [11]:
# Distinct tissue types in the training split (a missing value would count as one type)
print("NUMBER OF TISSUES:")
print(ds_train['type'].nunique(dropna=False))

NUMBER OF TISSUES:
13


In [10]:
# Advanced weighting: "Other" gets 10% of the smallest relevant weight, the background
# gets 100% of it, and all weights are normalized together.
compute_class_weights(
    ds_train,
    cell_type_cols,
    ignore_cat=["Other"],
    fraction_background=1.0,
    fraction_unrelevant=0.1,
)

WEIGHTS FOR FT LOSS USING LOGARITHMIC SCALING:
['Background', 'Immune', 'Stromal', 'Epithelial', 'Melanocyte', 'Other']
[0.131, 0.131, 0.175, 0.137, 0.414, 0.013]
