In [1]:
!pip install --no-deps /kaggle/input/libauc-1-2-0/libauc-1.2.0-py3-none-any.whl


Processing /kaggle/input/libauc-1-2-0/libauc-1.2.0-py3-none-any.whl
libauc is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.


In [2]:
from libauc.losses import AUCM_MultiLabel, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import densenet121 as DenseNet121
from libauc.datasets import CheXpert
from libauc.metrics import auc_roc_score # for multi-task

from PIL import Image
import numpy as np
import torch 
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn.functional as F   

import pandas as pd
import cv2

from torchvision.models import densenet121, DenseNet121_Weights
import torchvision.transforms as tfs

import torch.nn as nn

import time

In [3]:
fair_scaling_coef = 0.5
fair_scaling_sinkhorn_blur = .1
fair_scaling_temperature = 1.
fair_right_asymptote = 1.

labels = [
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
    'No Finding'
]

labels_small = [
    'Cardiomegaly',
    'Pneumonia',
    'Pleural Effusion',
   'Fracture',
    'No Finding'
]

labels_abbr_small = [
    'Cd',
    'Pa',
    'Ef',
    'Fr',
    'NF'
]

In [4]:
def df_age_disaggregation(df, left_edge = 15, right_edge = 105, bin_width = 5):

    # define age bins and labels
    bins = list(range(left_edge, right_edge, bin_width)) 
    labels = [f"{b}–{b+4}" for b in bins[:-1]]

    if right_edge>105:
        bins.append(120)
        labels.append('>=100')
    
    df["age_group"]= pd.cut(
        df["Age"],
        bins=bins,
        labels=labels,
        right=False)
    
    return df

def age_group_to_index(age_group):
  
    age_groups_order = [
        '15–19', '20–24', '25–29', '30–34', '35–39',
        '40–44', '45–49', '50–54', '55–59', '60–64',
        '65–69', '70–74', '75–79', '80–84', '85–89',
        '90–94', '95–99'
    ]

    # Build lookup table
    group_to_idx = {g: i for i, g in enumerate(age_groups_order)}

    # Convert input to python list if needed
    if isinstance(age_group, torch.Tensor):
        age_group = age_group.tolist()

    indices = []
    for g in age_group:
        if g not in group_to_idx:
            raise ValueError(f"Unknown age group: {g}")
        indices.append(group_to_idx[g])

    return torch.tensor(indices, dtype=torch.long)

# Compute group performance (offline / per epoch)

def compute_group_auc(y_true, y_prob, sensitive_attr):
    """
    y_true: (N, C)
    y_prob: (N, C)
    sensitive_attr: (N,)
    """
    group_aucs = {}
    groups = np.unique(sensitive_attr)

    for g in groups:
        mask = sensitive_attr == g
        if mask.sum() < 10:
            continue  # avoid unstable estimates
        group_aucs[g] = np.mean(auc_roc_score(
            y_true[mask], y_prob[mask]
        ))
    return group_aucs

In [5]:
pip install geomloss

Note: you may need to restart the kernel to use updated packages.


In [6]:
# FIS code
'''
    Reference: 
    
    @misc{Luo2024,
      title={FairVision: Equitable Deep Learning for Eye Disease Screening via Fair Identity Scaling}, 
      author={Yan Luo and Muhammad Osama Khan and Yu Tian and Min Shi and Zehao Dou and Tobias Elze and Yi Fang and Mengyu Wang},
      year={2024},
      eprint={2310.02492},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2310.02492}
    }
'''

# GeomLoss is a third-party research library. Purpose: Efficient Optimal Transport (OT) loss computations for ML
# Repository: https://github.com/jeanfeydy/geomloss
from geomloss import SamplesLoss

class Fair_Loss_Scaler(nn.Module):
    def __init__(self, level='individual', fair_scaling_group_weights=None, 
                    fair_scaling_temperature=1., fair_scaling_coef=.5, sinkhorn_blur=0.1, right_asymptote=2.):
        super().__init__()
        self.level = level
        self.fair_scaling_temperature = fair_scaling_temperature
        self.sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=sinkhorn_blur)
        self.fair_scaling_coef = fair_scaling_coef
        #self.fair_scaling_coef = nn.Parameter(torch.tensor(fair_scaling_coef))
        self.right_asymptote = right_asymptote
        
    def forward(self, x, smp_xs, attr):
                
        # x: (B, C) → reduce to (B,)
        if x.ndim == 2:
            x = x.mean(dim=1)

        # individual weights
        individual_weights = torch.softmax(x.detach() / self.fair_scaling_temperature, dim=0) * self.right_asymptote
    
        # total distribution
        # ttl_smp_x = torch.cat(smp_xs)
        ttl_smp_x = torch.stack(smp_xs)  # [num_groups]
    
        # distances_distributions = torch.zeros(len(smp_xs), device=x.device)
        distances_distributions = torch.zeros(len(smp_xs), device=x.device, dtype=x.dtype)
    
        # for i, smp_x in enumerate(smp_xs):
        #     distances_distributions[i] = self.sinkhorn_loss(
        #         smp_x.view(1, -1, 1),
        #         ttl_smp_x.view(1, -1, 1)
        #     )

        for i, smp_x in enumerate(smp_xs):
            distances_distributions[i] = self.sinkhorn_loss(
                smp_x.view(1, 1, 1),           # (B=1, N=1, D=1)
                ttl_smp_x.view(1, -1, 1)       # (B=1, M=G, D=1)
            )
    
        group_weights = torch.softmax(
            distances_distributions[attr.long()] / self.fair_scaling_temperature,
            dim=0
        ) * self.right_asymptote
    
        loss = (
            ((1 - self.fair_scaling_coef) * individual_weights
             + self.fair_scaling_coef * group_weights)
            * x
        ).mean()
    
        return loss


In [7]:
class CheXpert(Dataset):
    '''
    Reference: 
        @inproceedings{yuan2021robust,
            title={Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification},
            author={Yuan, Zhuoning and Yan, Yan and Sonka, Milan and Yang, Tianbao},
            booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
            year={2021}
            }
    '''
    def __init__(self, 
                 csv_path, 
                 image_root_path='',
                 image_size=320,
                 class_index=0, 
                 seed=123,
                 # verbose=True,
                 train_cols=labels_small,
                 mode='train'):
        
    
        # load data from csv
        self.df = pd.read_csv(csv_path)
        self.df['Path'] = self.df['Path'].str.replace("\\", "/")
        self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0-small/', '')
        self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0/', '')
       
        self._num_images = len(self.df)
        
        assert class_index in [-1, 0, 1, 2, 3, 4], 'Out of selection!'
        assert image_root_path != '', 'You need to pass the correct location for the dataset!'

        if class_index == -1: # 5 classes
            print ('Multi-label mode: True, Number of classes: [%d]'%len(train_cols))
            self.select_cols = train_cols
            self.value_counts_dict = {}
            for class_key, select_col in enumerate(train_cols):
                class_value_counts_dict = self.df[select_col].value_counts().to_dict()
                self.value_counts_dict[class_key] = class_value_counts_dict
        else:       # 1 class
            self.select_cols = [train_cols[class_index]]  # this var determines the number of classes
            self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict()
        
        self.mode = mode
        self.class_index = class_index
        self.image_size = image_size
        
        self._images_list =  [image_root_path+path for path in self.df['Path'].tolist()]
        if class_index != -1:
            self._labels_list = self.df[train_cols].values[:, class_index].tolist()
        else:
            self._labels_list = self.df[train_cols].values.tolist()

        sex_map = {"Male": 1, "Female": 0}
        self.sex = self.df["Sex"].map(sex_map).values

        self.sex_str = self.df["Sex"].values

        self.age = self.df["Age"].values

        data = self.df["Age"].values
        df = df_age_disaggregation(pd.DataFrame(data, columns=['Age']))
        self.age_group = df['age_group']

        
        self.age_bin = age_group_to_index(self.age_group)
        assert self.age_bin.min() >= 0 and self.age_bin.max() <= 16
 
    @property        
    def class_counts(self):
        return self.value_counts_dict
    
    @property
    def num_classes(self):
        return len(self.select_cols)
       
    @property  
    def data_size(self):
        return self._num_images 
    
    def image_augmentation(self, image):
        img_aug = tfs.Compose([tfs.RandomAffine(degrees=(-15, 15), translate=(0.05, 0.05), scale=(0.95, 1.05), fill=128)]) # pytorch 3.7: fillcolor --> fill
        image = img_aug(image)
        return image
    
    def __len__(self):
        return self._num_images
    
    def __getitem__(self, idx):

        image = cv2.imread(self._images_list[idx], 0)
        image = Image.fromarray(image)
        if self.mode == 'train':
            image = self.image_augmentation(image)
        image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        
        # resize and normalize; e.g., ToTensor()
        image = cv2.resize(image, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)  
        image = image/255.0
        __mean__ = np.array([[[0.485, 0.456, 0.406]]])
        __std__ =  np.array([[[0.229, 0.224, 0.225]  ]]) 
        image = (image-__mean__)/__std__
        image = image.transpose((2, 0, 1)).astype(np.float32)
        if self.class_index != -1: # multi-class mode
            label = np.array(self._labels_list[idx]).reshape(-1).astype(np.float32)
        else:
            label = np.array(self._labels_list[idx]).reshape(-1).astype(np.float32)

        attributes = {
            "Sex": self.sex[idx],  
            "Age": self.age[idx],
            "AgeBin": self.age_bin[idx]
        }
                
        return image, label, attributes



if __name__ == '__main__':

    root = '/kaggle/input/chexpert-v1-0/CheXpert-v1.0/'
    # root = "D:/chexpertchestxrays/CheXpert-v1.0/CheXpert-v1.0/"
    traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, image_size=320, mode='train', class_index=0)
    testSet =  CheXpert(csv_path=root+'valid.csv',  image_root_path=root, image_size=320, mode='valid', class_index=0)

    trainloader =  torch.utils.data.DataLoader(traindSet, batch_size=32, num_workers=2, drop_last=True, shuffle=True)
    testloader =  torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=2, drop_last=False, shuffle=False)


In [8]:
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

root = '/kaggle/input/chexpert-v1-0/CheXpert-v1.0/'
traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, image_size=224, mode='train', class_index=-1)
testSet =  CheXpert(csv_path=root+'valid.csv',  image_root_path=root, image_size=224, mode='valid', class_index=-1)
trainloader =  torch.utils.data.DataLoader(traindSet, batch_size=32, num_workers=0, shuffle=True, pin_memory=False)
testloader =  torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=0, shuffle=False, pin_memory=False)

# paramaters
SEED = 123
BATCH_SIZE = 32
lr = 0.1 
epoch_decay = 2e-3
# epoch_decay = 0.95 
weight_decay = 1e-5
margin = 1.0
total_epochs = 2

# model
set_all_seeds(SEED)
# model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=5)
model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
model.classifier = nn.Linear(model.classifier.in_features, 5) 
# model = model.cuda()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# define loss & optimizer
loss_fn = AUCM_MultiLabel(num_classes=5)
optimizer = PESG(model, 
                 loss_fn=loss_fn,
                 lr=lr, 
                 margin=margin, 
                 epoch_decay=epoch_decay, 
                 weight_decay=weight_decay)

Multi-label mode: True, Number of classes: [5]
Multi-label mode: True, Number of classes: [5]


In [9]:
# training densenet121 without fairness consideration
print ('Start Training')
print ('-'*30)

best_val_auc = 0 
# Measure training time
start = time.time()
for epoch in range(total_epochs):
    if epoch > 0:
        optimizer.update_regularizer(decay_factor=10)    

    for idx, data in enumerate(trainloader):
      # train_data, train_labels = data
      train_data, train_labels, train_attributes = data  
      train_data, train_labels  = train_data.cuda(), train_labels.cuda()
      y_pred = model(train_data)
      y_pred = torch.sigmoid(y_pred)
      loss = loss_fn(y_pred, train_labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
              
      # validation  
      if idx % 400 == 0:
         model.eval()
         with torch.no_grad():    
              test_pred = []
              test_true = [] 
              test_attributes_true = [] 
              for jdx, data in enumerate(testloader):
                  # test_data, test_labels = data
                  test_data, test_labels, test_attributes = data
                  test_data = test_data.cuda()
                  y_pred = model(test_data)
                  y_pred = torch.sigmoid(y_pred)
                  test_pred.append(y_pred.cpu().detach().numpy())
                  test_true.append(test_labels.numpy())
                  test_attributes_true.append(test_attributes["Sex"].cpu().numpy())
            
              test_true = np.concatenate(test_true)
              test_pred = np.concatenate(test_pred)
              test_attributes_true = np.concatenate(test_attributes_true)

              group_aucs = compute_group_auc(test_true, test_pred, test_attributes_true)
          
              val_auc_mean = np.mean(auc_roc_score(test_true, test_pred)) 
              model.train()

              if best_val_auc < val_auc_mean:
                 best_val_auc = val_auc_mean
                 torch.save(model.state_dict(), 'aucm_pretrained_model.pth')

              print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
              print("Group AUCs:", group_aucs)

end = time.time()
training_time = (end - start)
print(f"Training time per epoch: {training_time:.2f} sec")

Start Training
------------------------------
Epoch=0, BatchID=0, Val_AUC=0.4767, Best_Val_AUC=0.4767
Group AUCs: {0: 0.47246201704067214, 1: 0.4797506233163357}
Epoch=0, BatchID=400, Val_AUC=0.6713, Best_Val_AUC=0.6713
Group AUCs: {0: 0.6825045390425128, 1: 0.6483459650141441}
Epoch=0, BatchID=800, Val_AUC=0.6758, Best_Val_AUC=0.6758
Group AUCs: {0: 0.6796136120247287, 1: 0.6698635895269837}
Epoch=0, BatchID=1200, Val_AUC=0.6948, Best_Val_AUC=0.6948
Group AUCs: {0: 0.687710952443678, 1: 0.6953653248016336}
Epoch=0, BatchID=1600, Val_AUC=0.6773, Best_Val_AUC=0.6948
Group AUCs: {0: 0.6530487848955475, 1: 0.6946249128746447}
Epoch=0, BatchID=2000, Val_AUC=0.6840, Best_Val_AUC=0.6948
Group AUCs: {0: 0.6633151817405532, 1: 0.6960110198071174}
Epoch=0, BatchID=2400, Val_AUC=0.6876, Best_Val_AUC=0.6948
Group AUCs: {0: 0.6811634111249051, 1: 0.6882507826998794}
Epoch=0, BatchID=2800, Val_AUC=0.6999, Best_Val_AUC=0.6999
Group AUCs: {0: 0.692915367103084, 1: 0.6982131634994435}
Epoch=0, BatchID

In [10]:
# show auc roc scores for each task 
auc_roc_score(test_true, test_pred)

[0.7568214032600993,
 0.8982300884955753,
 0.913218339440522,
 0.0,
 0.8888292158968851]

In [11]:
df_test = pd.DataFrame({
    "age_group": testSet.age_group,
    "Sex": testSet.sex_str,
    "Age": testSet.age
})

test_true_df = pd.DataFrame(
    test_true,
    columns=labels_small
)

test_pred_df = pd.DataFrame(
    test_pred,
    columns=labels_abbr_small
)

df_test = pd.concat([df_test.reset_index(drop=True),
                     test_true_df.reset_index(drop=True),
                    test_pred_df.reset_index(drop=True)],
                    axis=1)


df_test = df_test.sort_values(['age_group', 'Sex'])

print(df_test)

    age_group     Sex  Age  Cardiomegaly  Pneumonia  Pleural Effusion  \
17      15–19  Female   19           0.0        0.0               0.0   
66      15–19  Female   18           0.0        0.0               0.0   
67      15–19  Female   18           0.0        0.0               0.0   
172     15–19    Male   19           0.0        0.0               0.0   
68      20–24  Female   23           0.0        0.0               0.0   
..        ...     ...  ...           ...        ...               ...   
104     90–94    Male   90           0.0        0.0               0.0   
126     90–94    Male   90           1.0        0.0               1.0   
201     90–94    Male   90           0.0        0.0               0.0   
210     90–94    Male   90           1.0        0.0               0.0   
233     90–94    Male   90           0.0        0.0               1.0   

     Fracture  No Finding        Cd        Pa        Ef        Fr        NF  
17        0.0         1.0  0.004460  0.285352

In [12]:
df_test.to_csv("/kaggle/working/" + "results_test_pred.csv", index=False)

In [13]:
# # training densenet121 with Original FIS approach for Sex attribute

# from torch.utils.data import Subset

# def endless_loader(dataloader):
#     while True:
#         for batch in dataloader:
#             yield batch

# sex_tensor = torch.tensor(traindSet.sex)  # shape (N,)

# # print(sex_tensor.shape)            # (N,)
# # print(sex_tensor.unique())  

# female_indices = (sex_tensor == 0).nonzero(as_tuple=True)[0]
# male_indices   = (sex_tensor == 1).nonzero(as_tuple=True)[0]

# # print(len(female_indices), len(male_indices))
# # print(len(female_indices) + len(male_indices) == len(sex_tensor))

# female_dataset = Subset(traindSet, female_indices.tolist())
# male_dataset   = Subset(traindSet, male_indices.tolist())

# # x, y, attr = female_dataset[0]
# # print(attr["Sex"])  #

# group_dataloaders = []

# group_dataset_loader = torch.utils.data.DataLoader(female_dataset, batch_size=32, num_workers=0, shuffle=True, drop_last=True, pin_memory=False)
# group_dataloaders.append(endless_loader(group_dataset_loader))

# group_dataset_loader = torch.utils.data.DataLoader(male_dataset, batch_size=32, num_workers=0, shuffle=True, drop_last=True, pin_memory=False)
# group_dataloaders.append(endless_loader(group_dataset_loader))

In [14]:
# # training densenet121 with Original FIS approach for Sex attribute

# from itertools import cycle

# group_iters = [cycle(dl) for dl in group_dataloaders]

# scaler = torch.amp.GradScaler('cuda')

# criterion = nn.BCEWithLogitsLoss(reduction='none')

# loss_scaler = Fair_Loss_Scaler(fair_scaling_coef=fair_scaling_coef, sinkhorn_blur=fair_scaling_sinkhorn_blur,
#                                         fair_scaling_temperature=fair_scaling_temperature, right_asymptote=fair_right_asymptote)



# print ('Start Training')
# print ('-'*30)

# # fis_loss = FISLoss(loss_fn, alpha=0.6)

# # train_attributes_all_sex = torch.tensor(traindSet.sex)
# # group_weights = compute_fis_group_weights(train_attributes_all_sex)

# best_val_auc = 0 
# for epoch in range(total_epochs):
#     if epoch > 0:
#         optimizer.update_regularizer(decay_factor=10)    

#     for idx, data in enumerate(trainloader):
        
#         with torch.amp.autocast(device_type='cuda'):
            
#             # train_data, train_labels = data
#             train_data, train_labels, train_attributes = data
#             train_data, train_labels  = train_data.cuda(), train_labels.cuda()
        
#             y_pred = model(train_data)
            
#             smp_losses = []
#             # for x in group_dataloaders:
#             for x in group_iters:
#                 smp_input, smp_target, smp_attr = next(x)
#                 smp_input = smp_input.to(device)
#                 smp_target = smp_target.to(device)
#                 with torch.no_grad():
#                     smp_pred = model(smp_input)
#                     if smp_pred.shape[1] == 1:
#                         smp_pred = smp_pred.squeeze(1)
#                         smp_loss_raw = criterion(smp_pred, smp_target)
#                         smp_loss = smp_loss_raw.mean(dim=1)
#                     elif smp_pred.shape[1] > 1:
#                         smp_loss_raw = criterion(smp_pred, smp_target.float())
#                         # smp_loss = criterion(smp_pred, smp_target.long())
#                         smp_loss = smp_loss_raw.mean(dim=1)
#                         smp_losses.append(smp_loss)
    
#             if y_pred.shape[1] == 1:
#                 y_pred = y_pred.squeeze(1)
#                 loss = criterion(y_pred, train_labels)
#                 pred_prob = torch.sigmoid(y_pred.detach())
#                 loss_per_sample = loss.mean(dim=1)  # (B,)
#             elif y_pred.shape[1] > 1:
#                 # loss = criterion(y_pred, train_labels.long())
#                 loss = criterion(y_pred, train_labels.float())
#                 # pred_prob = F.softmax(y_pred.detach(), dim=1)
#                 pred_prob = torch.sigmoid(y_pred.detach())
#                 loss_per_sample = loss.mean(dim=1)  # (B,)
                    
            
#             # print("main loss:", loss.shape)              # (B,) or scalar
#             # print("group loss:", smp_losses[0].shape)    # (B,)
#             # print("Sex attr:", smp_attr["Sex"].shape)
#             # print(loss_per_sample.shape)        # torch.Size([32])
#             # print(smp_losses[0].shape)          # torch.Size([32])
#             # print(train_attributes["Sex"].shape)#
#             loss = loss_scaler(loss_per_sample, smp_losses, attr=train_attributes["Sex"])

    
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()
        
#         optimizer.zero_grad()

              
#         # validation  
#         if idx % 400 == 0:
#             model.eval()
#             with torch.no_grad():
#                 test_pred = []
#                 test_true = [] 
#                 test_attributes_true = [] 
#                 for jdx, data in enumerate(testloader):
#                     # test_data, test_labels = data
#                     test_data, test_labels, test_attributes = data
#                     test_data = test_data.cuda()
#                     y_pred = model(test_data)
#                     y_pred = torch.sigmoid(y_pred)
#                     test_pred.append(y_pred.cpu().detach().numpy())
#                     test_true.append(test_labels.numpy())
#                     test_attributes_true.append(test_attributes["Sex"].cpu().numpy())
                
#                 test_true = np.concatenate(test_true)
#                 test_pred = np.concatenate(test_pred)
#                 test_attributes_true = np.concatenate(test_attributes_true)
    
#                 group_aucs = compute_group_auc(test_true, test_pred, test_attributes_true)
                           
#                 val_auc_mean = np.mean(auc_roc_score(test_true, test_pred)) 
#                 model.train()
    
#                 if best_val_auc < val_auc_mean:
#                     best_val_auc = val_auc_mean
#                     torch.save(model.state_dict(), 'aucm_pretrained_model_FIS_origin.pth')
    
#                 print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
#                 print("Group AUCs:", group_aucs)
              

In [17]:
# training densenet121 with Original FIS approach for Age attribute

scaler = torch.amp.GradScaler('cuda')

criterion = nn.BCEWithLogitsLoss(reduction='none')

loss_scaler = Fair_Loss_Scaler(fair_scaling_coef=fair_scaling_coef, sinkhorn_blur=fair_scaling_sinkhorn_blur,
                                        fair_scaling_temperature=fair_scaling_temperature, right_asymptote=fair_right_asymptote)

print ('Start Training')
print ('-'*30)

best_val_auc = 0 
# Measure training time
start = time.time()
for epoch in range(total_epochs):
    if epoch > 0:
        optimizer.update_regularizer(decay_factor=10)    

    for idx, data in enumerate(trainloader):
        
        with torch.amp.autocast(device_type='cuda'):
            
            # train_data, train_labels = data
            train_data, train_labels, train_attributes = data
            train_data, train_labels  = train_data.cuda(), train_labels.cuda() 
        
            y_pred = model(train_data)
            
    
            # Compute per-group loss inside the batch
            smp_losses = []
            smp_pred = model(train_data)
            smp_loss_raw = criterion(smp_pred, train_labels.float())
        
            for g in range(17):  
                mask = (train_attributes["AgeBin"] == g)
                
                if mask.any():
                    
                    smp_loss = smp_loss_raw[mask].mean()
                    smp_losses.append(smp_loss)
                else:
                    smp_loss = smp_loss_raw.new_tensor(0.0)
                    smp_losses.append(smp_loss)
                    
            if y_pred.shape[1] == 1:
                y_pred = y_pred.squeeze(1)
                loss = criterion(y_pred, train_labels)
                pred_prob = torch.sigmoid(y_pred.detach())
                loss_per_sample = loss.mean(dim=1)  # (B,)
            elif y_pred.shape[1] > 1:
                # loss = criterion(y_pred, train_labels.long())
                loss = criterion(y_pred, train_labels.float())
                # pred_prob = F.softmax(y_pred.detach(), dim=1)
                pred_prob = torch.sigmoid(y_pred.detach())
                loss_per_sample = loss.mean(dim=1)  # (B,)
                    
            # print(loss_per_sample.shape)        # torch.Size([32])
            # print(smp_losses[0].shape)          # torch.Size([32])
            # print(train_attributes["AgeBin"].shape)#
            
            loss = loss_scaler(loss_per_sample, smp_losses, attr=train_attributes["AgeBin"])

    
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        optimizer.zero_grad()

              
        # validation  
        if idx % 400 == 0:
            model.eval()
            with torch.no_grad():
                test_pred = []
                test_true = [] 
                test_attributes_true = [] 
                for jdx, data in enumerate(testloader):
                    # test_data, test_labels = data
                    test_data, test_labels, test_attributes = data
                    test_data = test_data.cuda()
                    y_pred = model(test_data)
                    y_pred = torch.sigmoid(y_pred)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_labels.numpy())
                    test_attributes_true.append(test_attributes["AgeBin"].cpu().numpy())
                
                test_true = np.concatenate(test_true)
                test_pred = np.concatenate(test_pred)
                test_attributes_true = np.concatenate(test_attributes_true)
    
                group_aucs = compute_group_auc(test_true, test_pred, test_attributes_true)
                group_aucs_mean = np.mean(list(group_aucs.values()))
                           
                val_auc_mean = np.mean(auc_roc_score(test_true, test_pred)) 
                model.train()
    
                if best_val_auc < val_auc_mean:
                    best_val_auc = val_auc_mean
                    torch.save(model.state_dict(), 'aucm_pretrained_model_FIS_origin_age.pth')
    
                print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
                print("Mean Group AUCs:", group_aucs_mean)

end = time.time()
training_time = (end - start)
print(f"Training time per epoch: {training_time:.2f} sec")

Start Training
------------------------------
Epoch=0, BatchID=0, Val_AUC=0.6841, Best_Val_AUC=0.6841
Mean Group AUCs: 0.5347079594826736
Epoch=0, BatchID=400, Val_AUC=0.6890, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5332128215288399
Epoch=0, BatchID=800, Val_AUC=0.6855, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5291097606055273
Epoch=0, BatchID=1200, Val_AUC=0.6834, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5258867444912069
Epoch=0, BatchID=1600, Val_AUC=0.6808, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5244060471184041
Epoch=0, BatchID=2000, Val_AUC=0.6800, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5231153566240524
Epoch=0, BatchID=2400, Val_AUC=0.6795, Best_Val_AUC=0.6890
Mean Group AUCs: 0.523349594220933
Epoch=0, BatchID=2800, Val_AUC=0.6769, Best_Val_AUC=0.6890
Mean Group AUCs: 0.5203518866145296
Epoch=0, BatchID=3200, Val_AUC=0.6803, Best_Val_AUC=0.6890
Mean Group AUCs: 0.522598417861061
Epoch=0, BatchID=3600, Val_AUC=0.6831, Best_Val_AUC=0.6890
Mean Group AUCs: 0.523688312075669
Epoch=0, Bat

In [18]:
# # show auc roc scores for each task 
auc_roc_score(test_true, test_pred)

[0.7226257973068745,
 0.9275442477876106,
 0.8907855929931182,
 0.0,
 0.8500268528464018]

In [19]:
df_test = pd.DataFrame({
    "age_group": testSet.age_group,
    "Sex": testSet.sex_str,
    "Age": testSet.age
})

test_true_df = pd.DataFrame(
    test_true,
    columns=labels_small
)

test_pred_df = pd.DataFrame(
    test_pred,
    columns=labels_abbr_small
)

df_test = pd.concat([df_test.reset_index(drop=True),
                     test_true_df.reset_index(drop=True),
                    test_pred_df.reset_index(drop=True)],
                    axis=1)


df_test = df_test.sort_values(['age_group', 'Sex'])

print(df_test)

    age_group     Sex  Age  Cardiomegaly  Pneumonia  Pleural Effusion  \
17      15–19  Female   19           0.0        0.0               0.0   
66      15–19  Female   18           0.0        0.0               0.0   
67      15–19  Female   18           0.0        0.0               0.0   
172     15–19    Male   19           0.0        0.0               0.0   
68      20–24  Female   23           0.0        0.0               0.0   
..        ...     ...  ...           ...        ...               ...   
104     90–94    Male   90           0.0        0.0               0.0   
126     90–94    Male   90           1.0        0.0               1.0   
201     90–94    Male   90           0.0        0.0               0.0   
210     90–94    Male   90           1.0        0.0               0.0   
233     90–94    Male   90           0.0        0.0               1.0   

     Fracture  No Finding        Cd        Pa        Ef        Fr        NF  
17        0.0         1.0  0.051299  0.236260

In [20]:
df_test.to_csv("/kaggle/working/" + "FIS_origin_results_test_pred_age.csv", index=False)

In [21]:
# FIS code for sex groups

class Fair_Loss_Scaler_sex(nn.Module):
    def __init__(self, level='individual', fair_scaling_group_weights=None, 
                    fair_scaling_temperature=1., fair_scaling_coef=.5, sinkhorn_blur=0.1, right_asymptote=2.):
        super().__init__()
        self.level = level
        self.fair_scaling_temperature = fair_scaling_temperature
        self.sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=sinkhorn_blur)
        self.fair_scaling_coef = fair_scaling_coef
        #self.fair_scaling_coef = nn.Parameter(torch.tensor(fair_scaling_coef))
        self.right_asymptote = right_asymptote
        
    def forward(self, x, smp_xs, attr):
                
        # x: (B, C) → reduce to (B,)
        if x.ndim == 2:
            x = x.mean(dim=1)

        # individual weights
        individual_weights = torch.softmax(x.detach() / self.fair_scaling_temperature, dim=0) * self.right_asymptote
    
        # total distribution
        ttl_smp_x = torch.cat(smp_xs)
        # ttl_smp_x = torch.stack(smp_xs)  # [num_groups]
    
        distances_distributions = torch.zeros(len(smp_xs), device=x.device)
        # distances_distributions = torch.zeros(len(smp_xs), device=x.device, dtype=x.dtype)
    
        for i, smp_x in enumerate(smp_xs):
            distances_distributions[i] = self.sinkhorn_loss(
                smp_x.view(1, -1, 1),
                ttl_smp_x.view(1, -1, 1)
            )

        # for i, smp_x in enumerate(smp_xs):
        #     distances_distributions[i] = self.sinkhorn_loss(
        #         smp_x.view(1, 1, 1),           # (B=1, N=1, D=1)
        #         ttl_smp_x.view(1, -1, 1)       # (B=1, M=G, D=1)
        #     )
    
        group_weights = torch.softmax(
            distances_distributions[attr.long()] / self.fair_scaling_temperature,
            dim=0
        ) * self.right_asymptote
    
        loss = (
            ((1 - self.fair_scaling_coef) * individual_weights
             + self.fair_scaling_coef * group_weights)
            * x
        ).mean()
    
        return loss


In [22]:
# training densenet121 with Original FIS approach for Sex attribute

scaler = torch.amp.GradScaler('cuda')

criterion = nn.BCEWithLogitsLoss(reduction='none')

loss_scaler = Fair_Loss_Scaler_sex(fair_scaling_coef=fair_scaling_coef, sinkhorn_blur=fair_scaling_sinkhorn_blur,
                                        fair_scaling_temperature=fair_scaling_temperature, right_asymptote=fair_right_asymptote)



print ('Start Training')
print ('-'*30)

best_val_auc = 0 
# Measure training time
start = time.time()
for epoch in range(total_epochs):
    if epoch > 0:
        optimizer.update_regularizer(decay_factor=10)    

    for idx, data in enumerate(trainloader):
        
        with torch.amp.autocast(device_type='cuda'):
            
            # train_data, train_labels = data
            train_data, train_labels, train_attributes = data
            train_data, train_labels  = train_data.cuda(), train_labels.cuda()
        
            y_pred = model(train_data)
            
   
            # Compute per-group loss inside the batch
            smp_losses = []
            smp_pred = model(train_data)
            smp_loss_raw = criterion(smp_pred, train_labels.float())
        
            for g in [0, 1]:  # example for Sex
                mask = (train_attributes["Sex"] == g)
                
                if mask.any():
                    # smp_loss_raw = criterion(smp_pred, smp_target.float())
                    smp_loss = smp_loss_raw[mask].mean(dim=1)
                    smp_losses.append(smp_loss)
                    
            if y_pred.shape[1] == 1:
                y_pred = y_pred.squeeze(1)
                loss = criterion(y_pred, train_labels)
                pred_prob = torch.sigmoid(y_pred.detach())
                loss_per_sample = loss.mean(dim=1)  # (B,)
            elif y_pred.shape[1] > 1:
                # loss = criterion(y_pred, train_labels.long())
                loss = criterion(y_pred, train_labels.float())
                # pred_prob = F.softmax(y_pred.detach(), dim=1)
                pred_prob = torch.sigmoid(y_pred.detach())
                loss_per_sample = loss.mean(dim=1)  # (B,)
                    
            
            loss = loss_scaler(loss_per_sample, smp_losses, attr=train_attributes["Sex"])

    
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        optimizer.zero_grad()

              
        # validation  
        if idx % 400 == 0:
            model.eval()
            with torch.no_grad():
                test_pred = []
                test_true = [] 
                test_attributes_true = [] 
                for jdx, data in enumerate(testloader):
                    # test_data, test_labels = data
                    test_data, test_labels, test_attributes = data
                    test_data = test_data.cuda()
                    y_pred = model(test_data)
                    y_pred = torch.sigmoid(y_pred)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_labels.numpy())
                    test_attributes_true.append(test_attributes["Sex"].cpu().numpy())
                
                test_true = np.concatenate(test_true)
                test_pred = np.concatenate(test_pred)
                test_attributes_true = np.concatenate(test_attributes_true)
    
                group_aucs = compute_group_auc(test_true, test_pred, test_attributes_true)
                           
                val_auc_mean = np.mean(auc_roc_score(test_true, test_pred)) 
                model.train()
    
                if best_val_auc < val_auc_mean:
                    best_val_auc = val_auc_mean
                    torch.save(model.state_dict(), 'aucm_pretrained_model_FIS_origin.pth')
    
                print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
                print("Group AUCs:", group_aucs)

end = time.time()
training_time = (end - start)
print(f"Training time per epoch: {training_time:.2f} sec")

Start Training
------------------------------
Epoch=0, BatchID=0, Val_AUC=0.6787, Best_Val_AUC=0.6787
Group AUCs: {0: 0.6633381160731556, 1: 0.6919479209166528}
Epoch=0, BatchID=400, Val_AUC=0.6822, Best_Val_AUC=0.6822
Group AUCs: {0: 0.6646649213701095, 1: 0.6967254555407348}
Epoch=0, BatchID=800, Val_AUC=0.6836, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6646373465321491, 1: 0.7001289827498474}
Epoch=0, BatchID=1200, Val_AUC=0.6794, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6626063355197431, 1: 0.6931282164567372}
Epoch=0, BatchID=1600, Val_AUC=0.6820, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6651267224313921, 1: 0.6965455251885827}
Epoch=0, BatchID=2000, Val_AUC=0.6812, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6640515195607989, 1: 0.6949609952596505}
Epoch=0, BatchID=2400, Val_AUC=0.6808, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6643219469473924, 1: 0.6945355606909935}
Epoch=0, BatchID=2800, Val_AUC=0.6835, Best_Val_AUC=0.6836
Group AUCs: {0: 0.6639420486950455, 1: 0.6999904113001217}
Epoch=0, BatchI

In [23]:
# # show auc roc scores for each task 
auc_roc_score(test_true, test_pred)

[0.7346739900779589,
 0.9231194690265487,
 0.8974886048797927,
 0.0,
 0.860499462943072]

In [24]:
df_test = pd.DataFrame({
    "age_group": testSet.age_group,
    "Sex": testSet.sex_str,
    "Age": testSet.age
})

test_true_df = pd.DataFrame(
    test_true,
    columns=labels_small
)

test_pred_df = pd.DataFrame(
    test_pred,
    columns=labels_abbr_small
)

df_test = pd.concat([df_test.reset_index(drop=True),
                     test_true_df.reset_index(drop=True),
                    test_pred_df.reset_index(drop=True)],
                    axis=1)


df_test = df_test.sort_values(['age_group', 'Sex'])

print(df_test)

    age_group     Sex  Age  Cardiomegaly  Pneumonia  Pleural Effusion  \
17      15–19  Female   19           0.0        0.0               0.0   
66      15–19  Female   18           0.0        0.0               0.0   
67      15–19  Female   18           0.0        0.0               0.0   
172     15–19    Male   19           0.0        0.0               0.0   
68      20–24  Female   23           0.0        0.0               0.0   
..        ...     ...  ...           ...        ...               ...   
104     90–94    Male   90           0.0        0.0               0.0   
126     90–94    Male   90           1.0        0.0               1.0   
201     90–94    Male   90           0.0        0.0               0.0   
210     90–94    Male   90           1.0        0.0               0.0   
233     90–94    Male   90           0.0        0.0               1.0   

     Fracture  No Finding        Cd        Pa        Ef        Fr        NF  
17        0.0         1.0  0.076875  0.253473

In [26]:
df_test.to_csv("/kaggle/working/" + "FIS_origin_results_test_pred.csv", index=False)